This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a0ad376840 [Minor] Refactor approx_percentile (#11769)
a0ad376840 is described below
commit a0ad376840daac8fdfecee5a4988c585350c629b
Author: Daniël Heres <[email protected]>
AuthorDate: Fri Aug 2 02:47:27 2024 +0200
[Minor] Refactor approx_percentile (#11769)
* Refactor approx_percentile
* Refactor approx_percentile
* Types
* Types
* Types
---
.../functions-aggregate/src/approx_median.rs | 2 +-
.../src/approx_percentile_cont.rs | 8 +--
.../physical-expr-common/src/aggregate/tdigest.rs | 62 +++++++++++++---------
3 files changed, 41 insertions(+), 31 deletions(-)
diff --git a/datafusion/functions-aggregate/src/approx_median.rs
b/datafusion/functions-aggregate/src/approx_median.rs
index e12e3445a8..c386ad89f0 100644
--- a/datafusion/functions-aggregate/src/approx_median.rs
+++ b/datafusion/functions-aggregate/src/approx_median.rs
@@ -78,7 +78,7 @@ impl AggregateUDFImpl for ApproxMedian {
Ok(vec![
Field::new(format_state_name(args.name, "max_size"), UInt64,
false),
Field::new(format_state_name(args.name, "sum"), Float64, false),
- Field::new(format_state_name(args.name, "count"), Float64, false),
+ Field::new(format_state_name(args.name, "count"), UInt64, false),
Field::new(format_state_name(args.name, "max"), Float64, false),
Field::new(format_state_name(args.name, "min"), Float64, false),
Field::new_list(
diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
index 844e48f0a4..af2a26fd05 100644
--- a/datafusion/functions-aggregate/src/approx_percentile_cont.rs
+++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs
@@ -214,7 +214,7 @@ impl AggregateUDFImpl for ApproxPercentileCont {
),
Field::new(
format_state_name(args.name, "count"),
- DataType::Float64,
+ DataType::UInt64,
false,
),
Field::new(
@@ -406,7 +406,7 @@ impl Accumulator for ApproxPercentileAccumulator {
}
fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
- if self.digest.count() == 0.0 {
+ if self.digest.count() == 0 {
return ScalarValue::try_from(self.return_type.clone());
}
let q = self.digest.estimate_quantile(self.percentile);
@@ -487,8 +487,8 @@ mod tests {
ApproxPercentileAccumulator::new_with_max_size(0.5,
DataType::Float64, 100);
accumulator.merge_digests(&[t1]);
- assert_eq!(accumulator.digest.count(), 50_000.0);
+ assert_eq!(accumulator.digest.count(), 50_000);
accumulator.merge_digests(&[t2]);
- assert_eq!(accumulator.digest.count(), 100_000.0);
+ assert_eq!(accumulator.digest.count(), 100_000);
}
}
diff --git a/datafusion/physical-expr-common/src/aggregate/tdigest.rs
b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
index 1da3d7180d..070ebc4648 100644
--- a/datafusion/physical-expr-common/src/aggregate/tdigest.rs
+++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs
@@ -47,6 +47,17 @@ macro_rules! cast_scalar_f64 {
};
}
+// Cast a non-null [`ScalarValue::UInt64`] to an [`u64`], or
+// panic.
+macro_rules! cast_scalar_u64 {
+ ($value:expr ) => {
+ match &$value {
+ ScalarValue::UInt64(Some(v)) => *v,
+ v => panic!("invalid type {:?}", v),
+ }
+ };
+}
+
/// This trait is implemented for each type a [`TDigest`] can operate on,
/// allowing it to support both numerical rust types (obtained from
/// `PrimitiveArray` instances), and [`ScalarValue`] instances.
@@ -142,7 +153,7 @@ pub struct TDigest {
centroids: Vec<Centroid>,
max_size: usize,
sum: f64,
- count: f64,
+ count: u64,
max: f64,
min: f64,
}
@@ -153,7 +164,7 @@ impl TDigest {
centroids: Vec::new(),
max_size,
sum: 0_f64,
- count: 0_f64,
+ count: 0,
max: f64::NAN,
min: f64::NAN,
}
@@ -164,14 +175,14 @@ impl TDigest {
centroids: vec![centroid.clone()],
max_size,
sum: centroid.mean * centroid.weight,
- count: 1_f64,
+ count: 1,
max: centroid.mean,
min: centroid.mean,
}
}
#[inline]
- pub fn count(&self) -> f64 {
+ pub fn count(&self) -> u64 {
self.count
}
@@ -203,7 +214,7 @@ impl Default for TDigest {
centroids: Vec::new(),
max_size: 100,
sum: 0_f64,
- count: 0_f64,
+ count: 0,
max: f64::NAN,
min: f64::NAN,
}
@@ -211,8 +222,8 @@ impl Default for TDigest {
}
impl TDigest {
- fn k_to_q(k: f64, d: f64) -> f64 {
- let k_div_d = k / d;
+ fn k_to_q(k: u64, d: usize) -> f64 {
+ let k_div_d = k as f64 / d as f64;
if k_div_d >= 0.5 {
let base = 1.0 - k_div_d;
1.0 - 2.0 * base * base
@@ -244,12 +255,12 @@ impl TDigest {
}
let mut result = TDigest::new(self.max_size());
- result.count = self.count() + (sorted_values.len() as f64);
+ result.count = self.count() + sorted_values.len() as u64;
let maybe_min = *sorted_values.first().unwrap();
let maybe_max = *sorted_values.last().unwrap();
- if self.count() > 0.0 {
+ if self.count() > 0 {
result.min = self.min.min(maybe_min);
result.max = self.max.max(maybe_max);
} else {
@@ -259,10 +270,10 @@ impl TDigest {
let mut compressed: Vec<Centroid> = Vec::with_capacity(self.max_size);
- let mut k_limit: f64 = 1.0;
+ let mut k_limit: u64 = 1;
let mut q_limit_times_count =
- Self::k_to_q(k_limit, self.max_size as f64) * result.count();
- k_limit += 1.0;
+ Self::k_to_q(k_limit, self.max_size) * result.count() as f64;
+ k_limit += 1;
let mut iter_centroids = self.centroids.iter().peekable();
let mut iter_sorted_values = sorted_values.iter().peekable();
@@ -309,8 +320,8 @@ impl TDigest {
compressed.push(curr.clone());
q_limit_times_count =
- Self::k_to_q(k_limit, self.max_size as f64) *
result.count();
- k_limit += 1.0;
+ Self::k_to_q(k_limit, self.max_size) * result.count() as
f64;
+ k_limit += 1;
curr = next;
}
}
@@ -381,7 +392,7 @@ impl TDigest {
let mut centroids: Vec<Centroid> = Vec::with_capacity(n_centroids);
let mut starts: Vec<usize> = Vec::with_capacity(digests.len());
- let mut count: f64 = 0.0;
+ let mut count = 0;
let mut min = f64::INFINITY;
let mut max = f64::NEG_INFINITY;
@@ -389,8 +400,8 @@ impl TDigest {
for digest in digests.iter() {
starts.push(start);
- let curr_count: f64 = digest.count();
- if curr_count > 0.0 {
+ let curr_count = digest.count();
+ if curr_count > 0 {
min = min.min(digest.min);
max = max.max(digest.max);
count += curr_count;
@@ -424,8 +435,8 @@ impl TDigest {
let mut result = TDigest::new(max_size);
let mut compressed: Vec<Centroid> = Vec::with_capacity(max_size);
- let mut k_limit: f64 = 1.0;
- let mut q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) *
(count);
+ let mut k_limit = 1;
+ let mut q_limit_times_count = Self::k_to_q(k_limit, max_size) * count
as f64;
let mut iter_centroids = centroids.iter_mut();
let mut curr = iter_centroids.next().unwrap();
@@ -444,8 +455,8 @@ impl TDigest {
sums_to_merge = 0_f64;
weights_to_merge = 0_f64;
compressed.push(curr.clone());
- q_limit_times_count = Self::k_to_q(k_limit, max_size as f64) *
(count);
- k_limit += 1.0;
+ q_limit_times_count = Self::k_to_q(k_limit, max_size) * count
as f64;
+ k_limit += 1;
curr = centroid;
}
}
@@ -468,8 +479,7 @@ impl TDigest {
return 0.0;
}
- let count_ = self.count;
- let rank = q * count_;
+ let rank = q * self.count as f64;
let mut pos: usize;
let mut t;
@@ -479,7 +489,7 @@ impl TDigest {
}
pos = 0;
- t = count_;
+ t = self.count as f64;
for (k, centroid) in self.centroids.iter().enumerate().rev() {
t -= centroid.weight();
@@ -581,7 +591,7 @@ impl TDigest {
vec![
ScalarValue::UInt64(Some(self.max_size as u64)),
ScalarValue::Float64(Some(self.sum)),
- ScalarValue::Float64(Some(self.count)),
+ ScalarValue::UInt64(Some(self.count)),
ScalarValue::Float64(Some(self.max)),
ScalarValue::Float64(Some(self.min)),
ScalarValue::List(arr),
@@ -627,7 +637,7 @@ impl TDigest {
Self {
max_size,
sum: cast_scalar_f64!(state[1]),
- count: cast_scalar_f64!(&state[2]),
+ count: cast_scalar_u64!(&state[2]),
max,
min,
centroids,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]