This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a0ad376840 [Minor] Refactor approx_percentile (#11769)
a0ad376840 is described below

commit a0ad376840daac8fdfecee5a4988c585350c629b
Author: Daniël Heres <[email protected]>
AuthorDate: Fri Aug 2 02:47:27 2024 +0200

    [Minor] Refactor approx_percentile (#11769)
    
    * Refactor approx_percentile
    
    * Refactor approx_percentile
    
    * Types
    
    * Types
    
    * Types
---
 .../functions-aggregate/src/approx_median.rs       |  2 +-
 .../src/approx_percentile_cont.rs                  |  8 +--
 .../physical-expr-common/src/aggregate/tdigest.rs  | 62 +++++++++++++---------
 3 files changed, 41 insertions(+), 31 deletions(-)

diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs
index e12e3445a8..c386ad89f0 100644
--- a/datafusion/functions-aggregate/src/approx_median.rs
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -78,7 +78,7 @@ impl AggregateUDFImpl for ApproxMedian {
         Ok(vec![
             Field::new(format_state_name(args.name, "max_size"), UInt64, 
false),
             Field::new(format_state_name(args.name, "sum"), Float64, false),
-            Field::new(format_state_name(args.name, "count"), Float64, false),
+            Field::new(format_state_name(args.name, "count"), UInt64, false),
             Field::new(format_state_name(args.name, "max"), Float64, false),
             Field::new(format_state_name(args.name, "min"), Float64, false),
             Field::new_list(
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index 844e48f0a4..af2a26fd05 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -214,7 +214,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
             ),
             Field::new(
                 format_state_name(args.name, "count"),
-                DataType::Float64,
+                DataType::UInt64,
                 false,
             ),
             Field::new(
@@ -406,7 +406,7 @@ impl Accumulator for ApproxPercentileAccumulator {
     }
 
     fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
-        if self.digest.count() == 0.0 {
+        if self.digest.count() == 0 {
             return ScalarValue::try_from(self.return_type.clone());
         }
         let q = self.digest.estimate_quantile(self.percentile);
@@ -487,8 +487,8 @@ mod tests {
             ApproxPercentileAccumulator::new_with_max_size(0.5, 
DataType::Float64, 100);
 
         accumulator.merge_digests(&[t1]);
-        assert_eq!(accumulator.digest.count(), 50_000.0);
+        assert_eq!(accumulator.digest.count(), 50_000);
         accumulator.merge_digests(&[t2]);
-        assert_eq!(accumulator.digest.count(), 100_000.0);
+        assert_eq!(accumulator.digest.count(), 100_000);
     }
 }
diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
index 1da3d7180d..070ebc4648 100644
--- a/datafusion/physical-expr-common/src/aggregate/tdigest.rs
+++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
@@ -47,6 +47,17 @@ macro_rules! cast_scalar_f64 {
     };
 }
 
+// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or
+// panic.
+macro_rules! cast_scalar_u64 {
+    ($value:expr ) => {
+        match &$value {
+            ScalarValue::UInt64(Some(v)) => *v,
+            v => panic!("invalid type {:?}", v),
+        }
+    };
+}
+
 /// This trait is implemented for each type a [`TDigest`] can operate on,
 /// allowing it to support both numerical rust types (obtained from
 /// `PrimitiveArray` instances), and [`ScalarValue`] instances.
@@ -142,7 +153,7 @@ pub struct TDigest {
     centroids: Vec<Centroid>,
     max_size: usize,
     sum: f64,
-    count: f64,
+    count: u64,
     max: f64,
     min: f64,
 }
@@ -153,7 +164,7 @@ impl TDigest {
             centroids: Vec::new(),
             max_size,
             sum: 0_f64,
-            count: 0_f64,
+            count: 0,
             max: f64::NAN,
             min: f64::NAN,
         }
@@ -164,14 +175,14 @@ impl TDigest {
             centroids: vec![centroid.clone()],
             max_size,
             sum: centroid.mean * centroid.weight,
-            count: 1_f64,
+            count: 1,
             max: centroid.mean,
             min: centroid.mean,
         }
     }
 
     #[inline]
-    pub fn count(&self) -> f64 {
+    pub fn count(&self) -> u64 {
         self.count
     }
 
@@ -203,7 +214,7 @@ impl Default for TDigest {
             centroids: Vec::new(),
             max_size: 100,
             sum: 0_f64,
-            count: 0_f64,
+            count: 0,
             max: f64::NAN,
             min: f64::NAN,
         }
@@ -211,8 +222,8 @@ impl Default for TDigest {
 }
 
 impl TDigest {
-    fn k_to_q(k: f64, d: f64) -> f64 {
-        let k_div_d = k / d;
+    fn k_to_q(k: u64, d: usize) -> f64 {
+        let k_div_d = k as f64 / d as f64;
         if k_div_d >= 0.5 {
             let base = 1.0 - k_div_d;
             1.0 - 2.0 * base * base
@@ -244,12 +255,12 @@ impl TDigest {
         }
 
         let mut result = TDigest::new(self.max_size());
-        result.count = self.count() + (sorted_values.len() as f64);
+        result.count = self.count() + sorted_values.len() as u64;
 
         let maybe_min = *sorted_values.first().unwrap();
         let maybe_max = *sorted_values.last().unwrap();
 
-        if self.count() > 0.0 {
+        if self.count() > 0 {
             result.min = self.min.min(maybe_min);
             result.max = self.max.max(maybe_max);
         } else {
@@ -259,10 +270,10 @@ impl TDigest {
 
         let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
 
-        let mut k_limit: f64 = 1.0;
+        let mut k_limit: u64 = 1;
         let mut q_limit_times_count =
-            Self::k_to_q(k_limit, self.max_size as f64) * result.count();
-        k_limit += 1.0;
+            Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
+        k_limit += 1;
 
         let mut iter_centroids = self.centroids.iter().peekable();
         let mut iter_sorted_values = sorted_values.iter().peekable();
@@ -309,8 +320,8 @@ impl TDigest {
 
                 compressed.push(curr.clone());
                 q_limit_times_count =
-                    Self::k_to_q(k_limit, self.max_size as f64) * 
result.count();
-                k_limit += 1.0;
+                    Self::k_to_q(k_limit, self.max_size) * result.count() as 
f64;
+                k_limit += 1;
                 curr = next;
             }
         }
@@ -381,7 +392,7 @@ impl TDigest {
         let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
         let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
 
-        let mut count: f64 = 0.0;
+        let mut count = 0;
         let mut min = f64::INFINITY;
         let mut max = f64::NEG_INFINITY;
 
@@ -389,8 +400,8 @@ impl TDigest {
         for digest in digests.iter() {
             starts.push(start);
 
-            let curr_count: f64 = digest.count();
-            if curr_count > 0.0 {
+            let curr_count = digest.count();
+            if curr_count > 0 {
                 min = min.min(digest.min);
                 max = max.max(digest.max);
                 count += curr_count;
@@ -424,8 +435,8 @@ impl TDigest {
         let mut result = TDigest::new(max_size);
         let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
 
-        let mut k_limit: f64 = 1.0;
-        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * 
(count);
+        let mut k_limit = 1;
+        let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count 
as f64;
 
         let mut iter_centroids = centroids.iter_mut();
         let mut curr = iter_centroids.next().unwrap();
@@ -444,8 +455,8 @@ impl TDigest {
                 sums_to_merge = 0_f64;
                 weights_to_merge = 0_f64;
                 compressed.push(curr.clone());
-                q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) * 
(count);
-                k_limit += 1.0;
+                q_limit_times_count = Self::k_to_q(k_limit, max_size) * count 
as f64;
+                k_limit += 1;
                 curr = centroid;
             }
         }
@@ -468,8 +479,7 @@ impl TDigest {
             return 0.0;
         }
 
-        let count_ = self.count;
-        let rank = q * count_;
+        let rank = q * self.count as f64;
 
         let mut pos: usize;
         let mut t;
@@ -479,7 +489,7 @@ impl TDigest {
             }
 
             pos = 0;
-            t = count_;
+            t = self.count as f64;
 
             for (k, centroid) in self.centroids.iter().enumerate().rev() {
                 t -= centroid.weight();
@@ -581,7 +591,7 @@ impl TDigest {
         vec![
             ScalarValue::UInt64(Some(self.max_size as u64)),
             ScalarValue::Float64(Some(self.sum)),
-            ScalarValue::Float64(Some(self.count)),
+            ScalarValue::UInt64(Some(self.count)),
             ScalarValue::Float64(Some(self.max)),
             ScalarValue::Float64(Some(self.min)),
             ScalarValue::List(arr),
@@ -627,7 +637,7 @@ impl TDigest {
         Self {
             max_size,
             sum: cast_scalar_f64!(state[1]),
-            count: cast_scalar_f64!(&state[2]),
+            count: cast_scalar_u64!(&state[2]),
             max,
             min,
             centroids,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to