This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 337353810df503d02245de02357fb1d6ba04f675 Author: Andrew Lamb <[email protected]> AuthorDate: Fri Jun 30 11:28:48 2023 -0400 complete accumulator --- datafusion/physical-expr/src/aggregate/average.rs | 76 ++++++++++++++++++----- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs index f81c704d8b..b23b555805 100644 --- a/datafusion/physical-expr/src/aggregate/average.rs +++ b/datafusion/physical-expr/src/aggregate/average.rs @@ -18,7 +18,7 @@ //! Defines physical expressions that can evaluated at runtime during query execution use arrow::array::AsArray; -use log::info; +use log::debug; use std::any::Any; use std::convert::TryFrom; @@ -45,6 +45,8 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_expr::Accumulator; use datafusion_row::accessor::RowAccessor; +use super::utils::Decimal128Averager; + /// AVG aggregate expression #[derive(Debug, Clone)] pub struct Avg { @@ -161,16 +163,29 @@ impl AggregateExpr for Avg { fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> { // instantiate specialized accumulator - match self.sum_data_type { - DataType::Decimal128(_, _) => { - Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type>::new( + match (&self.sum_data_type, &self.rt_data_type) { + ( + DataType::Decimal128(_sum_precision, sum_scale), + DataType::Decimal128(target_precision, target_scale), + ) => { + let decimal_averager = Decimal128Averager::try_new( + *sum_scale, + *target_precision, + *target_scale, + )?; + + let avg_fn = + move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); + + Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new( &self.sum_data_type, &self.rt_data_type, + avg_fn, ))) } _ => Err(DataFusionError::NotImplemented(format!( - "AvgGroupsAccumulator for {}", - self.sum_data_type + "AvgGroupsAccumulator for ({} --> {})", + self.sum_data_type, self.rt_data_type, ))), } } @@ -403,9 +418,13 @@ impl RowAccumulator for AvgRowAccumulator { } /// An accumulator to compute the average of PrimitiveArray<T>. -/// Stores values as native types +/// Stores values as native types, and does overflow checking #[derive(Debug)] -struct AvgGroupsAccumulator<T: ArrowNumericType + Send> { +struct AvgGroupsAccumulator<T, F> +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result<T::Native> + Send, +{ /// The type of the internal sum sum_data_type: DataType, @@ -415,13 +434,20 @@ struct AvgGroupsAccumulator<T: ArrowNumericType + Send> { /// Count per group (use u64 to make UInt64Array) counts: Vec<u64>, - // Sums per group, stored as the native type + /// Sums per group, stored as the native type sums: Vec<T::Native>, + + /// Function that computes the average (value / count) + avg_fn: F, } -impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> { - pub fn new(sum_data_type: &DataType, return_data_type: &DataType) -> Self { - info!( +impl<T, F> AvgGroupsAccumulator<T, F> +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result<T::Native> + Send, +{ + pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { + debug!( "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}", std::any::type_name::<T>() ); @@ -430,6 +456,7 @@ impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> { sum_data_type: sum_data_type.clone(), counts: vec![], sums: vec![], + avg_fn, } } @@ -500,7 +527,11 @@ impl<T: ArrowNumericType + Send> AvgGroupsAccumulator<T> { } } -impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> { +impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F> +where + T: ArrowNumericType + Send, + F: Fn(T::Native, u64) -> Result<T::Native> + Send, +{ fn update_batch( &mut self, values: &[ArrayRef], @@ -549,7 +580,22 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> { } fn evaluate(&mut self) -> Result<ArrayRef> { - todo!() + let counts = std::mem::take(&mut self.counts); + let sums = std::mem::take(&mut self.sums); + + let averages: Vec<T::Native> = sums + .into_iter() + .zip(counts.into_iter()) + .map(|(sum, count)| (self.avg_fn)(sum, count)) + .collect::<Result<Vec<_>>>()?; + + // TODO figure out how to do this without the iter / copy + let array = PrimitiveArray::<T>::from_iter_values(averages); + + // fix up decimal precision and scale for decimals + let array = set_decimal_precision(&self.return_data_type, Arc::new(array))?; + + Ok(array) } // return arrays for sums and counts @@ -563,7 +609,7 @@ impl<T: ArrowNumericType + Send> GroupsAccumulator for AvgGroupsAccumulator<T> { // TODO figure out how to do this without the iter / copy let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums); - // fix up decimal precision and scale + // fix up decimal precision and scale for decimals let sums = set_decimal_precision(&self.sum_data_type, Arc::new(sums))?; Ok(vec