Dandandan commented on a change in pull request #1539: URL: https://github.com/apache/arrow-datafusion/pull/1539#discussion_r782529863
########## File path: datafusion/src/physical_plan/tdigest/mod.rs ########## @@ -0,0 +1,818 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with this +// work for additional information regarding copyright ownership. The ASF +// licenses this file to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +//! An implementation of the [TDigest sketch algorithm] providing approximate +//! quantile calculations. +//! +//! The TDigest code in this module is modified from +//! https://github.com/MnO2/t-digest, itself a rust reimplementation of +//! [Facebook's Folly TDigest] implementation. +//! +//! Alterations include reduction of runtime heap allocations, broader type +//! support, (de-)serialisation support, reduced type conversions and null value +//! tolerance. +//! +//! [TDigest sketch algorithm]: https://arxiv.org/abs/1902.04023 +//! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h + +use arrow::datatypes::DataType; +use ordered_float::OrderedFloat; +use std::cmp::Ordering; + +use crate::{ + error::{DataFusionError, Result}, + scalar::ScalarValue, +}; + +// Cast a non-null [`ScalarValue::Float64`] to an [`OrderedFloat<f64>`], or +// panic. +macro_rules! 
cast_scalar_f64 { + ($value:expr ) => { + match &$value { + ScalarValue::Float64(Some(v)) => OrderedFloat::from(*v), + v => panic!("invalid type {:?}", v), + } + }; +} + +/// This trait is implemented for each type a [`TDigest`] can operate on, +/// allowing it to support both numerical rust types (obtained from +/// `PrimitiveArray` instances), and [`ScalarValue`] instances. +pub(crate) trait TryIntoOrderedF64 { + /// A fallible conversion of a possibly null `self` into a [`OrderedFloat<f64>`]. + /// + /// If `self` is null, this method must return `Ok(None)`. + /// + /// If `self` cannot be coerced to the desired type, this method must return + /// an `Err` variant. + fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>>; +} + +/// Generate an infallible conversion from `type` to an [`OrderedFloat<f64>`]. +macro_rules! impl_try_ordered_f64 { + ($type:ty) => { + impl TryIntoOrderedF64 for $type { + fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>> { + Ok(Some(OrderedFloat::from(*self as f64))) + } + } + }; +} + +impl_try_ordered_f64!(f64); +impl_try_ordered_f64!(f32); +impl_try_ordered_f64!(i64); +impl_try_ordered_f64!(i32); +impl_try_ordered_f64!(i16); +impl_try_ordered_f64!(i8); +impl_try_ordered_f64!(u64); +impl_try_ordered_f64!(u32); +impl_try_ordered_f64!(u16); +impl_try_ordered_f64!(u8); + +impl TryIntoOrderedF64 for ScalarValue { + fn try_as_f64(&self) -> Result<Option<OrderedFloat<f64>>> { + match self { + ScalarValue::Float32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Float64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int16(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::Int64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt8(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt16(v) => Ok(v.map(|v| 
OrderedFloat::from(v as f64))), + ScalarValue::UInt32(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + ScalarValue::UInt64(v) => Ok(v.map(|v| OrderedFloat::from(v as f64))), + + got => { + return Err(DataFusionError::NotImplemented(format!( + "Support for 'APPROX_QUANTILE' for data type {} is not implemented", + got + ))) + } + } + } +} + +/// Centroid implementation to the cluster mentioned in the paper. +#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct Centroid { + mean: OrderedFloat<f64>, + weight: OrderedFloat<f64>, +} + +impl PartialOrd for Centroid { + fn partial_cmp(&self, other: &Centroid) -> Option<Ordering> { + Some(self.cmp(other)) + } +} + +impl Ord for Centroid { + fn cmp(&self, other: &Centroid) -> Ordering { + self.mean.cmp(&other.mean) + } +} + +impl Centroid { + pub(crate) fn new( + mean: impl Into<OrderedFloat<f64>>, + weight: impl Into<OrderedFloat<f64>>, + ) -> Self { + Centroid { + mean: mean.into(), + weight: weight.into(), + } + } + + #[inline] + pub(crate) fn mean(&self) -> OrderedFloat<f64> { + self.mean + } + + #[inline] + pub(crate) fn weight(&self) -> OrderedFloat<f64> { + self.weight + } + + pub(crate) fn add( + &mut self, + sum: impl Into<OrderedFloat<f64>>, + weight: impl Into<OrderedFloat<f64>>, + ) -> f64 { + let new_sum = sum.into() + self.weight * self.mean; + let new_weight = self.weight + weight.into(); + self.weight = new_weight; + self.mean = new_sum / new_weight; + new_sum.into_inner() + } +} + +impl Default for Centroid { + fn default() -> Self { + Centroid { + mean: OrderedFloat::from(0.0), + weight: OrderedFloat::from(1.0), + } + } +} + +/// T-Digest to be operated on. 
+#[derive(Debug, PartialEq, Eq, Clone)] +pub(crate) struct TDigest { + centroids: Vec<Centroid>, + max_size: usize, + sum: OrderedFloat<f64>, + count: OrderedFloat<f64>, + max: OrderedFloat<f64>, + min: OrderedFloat<f64>, +} + +impl TDigest { + pub(crate) fn new(max_size: usize) -> Self { + TDigest { + centroids: Vec::new(), + max_size, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } + + #[inline] + pub(crate) fn count(&self) -> f64 { + self.count.into_inner() + } + + #[inline] + pub(crate) fn max(&self) -> f64 { + self.max.into_inner() + } + + #[inline] + pub(crate) fn min(&self) -> f64 { + self.min.into_inner() + } + + #[inline] + pub(crate) fn max_size(&self) -> usize { + self.max_size + } +} + +impl Default for TDigest { + fn default() -> Self { + TDigest { + centroids: Vec::new(), + max_size: 100, + sum: OrderedFloat::from(0.0), + count: OrderedFloat::from(0.0), + max: OrderedFloat::from(std::f64::NAN), + min: OrderedFloat::from(std::f64::NAN), + } + } +} + +impl TDigest { + fn k_to_q(k: f64, d: f64) -> OrderedFloat<f64> { + let k_div_d = k / d; + if k_div_d >= 0.5 { + let base = 1.0 - k_div_d; + 1.0 - 2.0 * base * base + } else { + 2.0 * k_div_d * k_div_d + } + .into() + } + + fn clamp( + v: OrderedFloat<f64>, + lo: OrderedFloat<f64>, + hi: OrderedFloat<f64>, + ) -> OrderedFloat<f64> { + if v > hi { + hi + } else if v < lo { + lo + } else { + v + } + } + + pub(crate) fn merge_unsorted<T: TryIntoOrderedF64>( + &self, + unsorted_values: impl IntoIterator<Item = T>, + ) -> Result<TDigest> { + let mut values = unsorted_values + .into_iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::<Result<Vec<_>>>()?; + + values.sort(); + + Ok(self.merge_sorted_f64(&values)) + } + + fn merge_sorted_f64(&self, sorted_values: &[OrderedFloat<f64>]) -> TDigest { + debug_assert!(is_sorted(&sorted_values), "unsorted input to TDigest"); + + if 
sorted_values.is_empty() { + return self.clone(); + } + + let mut result = TDigest::new(self.max_size()); + result.count = OrderedFloat::from(self.count() + (sorted_values.len() as f64)); + + let maybe_min = OrderedFloat::from(*sorted_values.first().unwrap()); + let maybe_max = OrderedFloat::from(*sorted_values.last().unwrap()); + + if self.count() > 0.0 { + result.min = std::cmp::min(self.min, maybe_min); + result.max = std::cmp::max(self.max, maybe_max); + } else { + result.min = maybe_min; + result.max = maybe_max; + } + + let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size); + + let mut k_limit: f64 = 1.0; + let mut q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + + let mut iter_centroids = self.centroids.iter().peekable(); + let mut iter_sorted_values = sorted_values.iter().peekable(); + + let mut curr: Centroid = if let Some(c) = iter_centroids.peek() { + let curr = **iter_sorted_values.peek().unwrap(); + if c.mean() < curr { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let mut weight_so_far = curr.weight(); + + let mut sums_to_merge = OrderedFloat::from(0.0); + let mut weights_to_merge = OrderedFloat::from(0.0); + + while iter_centroids.peek().is_some() || iter_sorted_values.peek().is_some() { + let next: Centroid = if let Some(c) = iter_centroids.peek() { + if iter_sorted_values.peek().is_none() + || c.mean() < **iter_sorted_values.peek().unwrap() + { + iter_centroids.next().unwrap().clone() + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + } + } else { + Centroid::new(*iter_sorted_values.next().unwrap(), 1.0) + }; + + let next_sum = next.mean() * next.weight(); + weight_so_far += next.weight(); + + if weight_so_far <= q_limit_times_count { + sums_to_merge += next_sum; + weights_to_merge += next.weight(); + } else { + 
result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + sums_to_merge = 0.0.into(); + weights_to_merge = 0.0.into(); + + compressed.push(curr.clone()); + q_limit_times_count = + Self::k_to_q(k_limit, self.max_size as f64) * result.count(); + k_limit += 1.0; + curr = next; + } + } + + result.sum = OrderedFloat::from( + result.sum.into_inner() + curr.add(sums_to_merge, weights_to_merge), + ); + compressed.push(curr); + compressed.shrink_to_fit(); + compressed.sort(); + + result.centroids = compressed; + result + } + + fn external_merge( + centroids: &mut Vec<Centroid>, + first: usize, + middle: usize, + last: usize, + ) { + let mut result: Vec<Centroid> = Vec::with_capacity(centroids.len()); + + let mut i = first; + let mut j = middle; + + while i < middle && j < last { + match centroids[i].cmp(&centroids[j]) { + Ordering::Less => { + result.push(centroids[i].clone()); + i += 1; + } + Ordering::Greater => { + result.push(centroids[j].clone()); + j += 1; + } + Ordering::Equal => { + result.push(centroids[i].clone()); Review comment: The branches are the same for less and equal so I think the whole match could be simplified to ```rust if centroids[i] <= centroids[j] { result.push(centroids[i].clone()); i += 1; } else { result.push(centroids[j].clone()); j += 1; } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
