This is an automated email from the ASF dual-hosted git repository.

yangjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 74b966ec34 Improve ApproxPercentileAccumulator merge api and fix bug (#10056)
74b966ec34 is described below

commit 74b966ec347df9c56d3061c2349d7c3188966997
Author: Yang Jiang <[email protected]>
AuthorDate: Tue Apr 16 13:40:37 2024 +0800

    Improve ApproxPercentileAccumulator merge api and fix bug (#10056)
    
    * improve ApproxPercentileAccumulator merge api and fix bug
    
    * add test for accumulator merge_digests
    
    * fix test
    
    * Reduce cloning in ApproxPercentileAccumulator
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 .../src/aggregate/approx_percentile_cont.rs        | 37 ++++++++++++++++++++--
 datafusion/physical-expr/src/aggregate/tdigest.rs  |  5 ++-
 2 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
index 3dbf1679e2..63a4c85f9e 100644
--- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
+++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
@@ -34,7 +34,7 @@ use datafusion_common::{
     ScalarValue,
 };
 use datafusion_expr::{Accumulator, ColumnarValue};
-use std::{any::Any, iter, sync::Arc};
+use std::{any::Any, sync::Arc};
 
 /// APPROX_PERCENTILE_CONT aggregate expression
 #[derive(Debug)]
@@ -284,7 +284,8 @@ impl ApproxPercentileAccumulator {
     }
 
     pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
-        self.digest = TDigest::merge_digests(digests);
+        let digests = digests.iter().chain(std::iter::once(&self.digest));
+        self.digest = TDigest::merge_digests(digests)
     }
 
     pub(crate) fn convert_to_float(values: &ArrayRef) -> Result<Vec<f64>> {
@@ -425,7 +426,6 @@ impl Accumulator for ApproxPercentileAccumulator {
                     .collect::<Result<Vec<_>>>()
                     .map(|state| TDigest::from_scalar_state(&state))
             })
-            .chain(iter::once(Ok(self.digest.clone())))
             .collect::<Result<Vec<_>>>()?;
 
         self.merge_digests(&states);
@@ -440,3 +440,34 @@ impl Accumulator for ApproxPercentileAccumulator {
             - std::mem::size_of_val(&self.return_type)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator;
+    use crate::aggregate::tdigest::TDigest;
+    use arrow_schema::DataType;
+
+    #[test]
+    fn test_combine_approx_percentile_accumulator() {
+        let mut digests: Vec<TDigest> = Vec::new();
+
+        // one TDigest with 50_000 values from 1 to 1_000
+        for _ in 1..=50 {
+            let t = TDigest::new(100);
+            let values: Vec<_> = (1..=1_000).map(f64::from).collect();
+            let t = t.merge_unsorted_f64(values);
+            digests.push(t)
+        }
+
+        let t1 = TDigest::merge_digests(&digests);
+        let t2 = TDigest::merge_digests(&digests);
+
+        let mut accumulator =
+            ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100);
+
+        accumulator.merge_digests(&[t1]);
+        assert_eq!(accumulator.digest.count(), 50_000.0);
+        accumulator.merge_digests(&[t2]);
+        assert_eq!(accumulator.digest.count(), 100_000.0);
+    }
+}
diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr/src/aggregate/tdigest.rs
index 78708df94c..e3b23b91d0 100644
--- a/datafusion/physical-expr/src/aggregate/tdigest.rs
+++ b/datafusion/physical-expr/src/aggregate/tdigest.rs
@@ -370,7 +370,10 @@ impl TDigest {
     }
 
     // Merge multiple T-Digests
-    pub(crate) fn merge_digests(digests: &[TDigest]) -> TDigest {
+    pub(crate) fn merge_digests<'a>(
+        digests: impl IntoIterator<Item = &'a TDigest>,
+    ) -> TDigest {
+        let digests = digests.into_iter().collect::<Vec<_>>();
         let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum();
         if n_centroids == 0 {
             return TDigest::default();

Reply via email to