This is an automated email from the ASF dual-hosted git repository. jiayuliu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push: new 82e8003 remove update and merge (#1582) 82e8003 is described below commit 82e80036d9dbcf1bacac8e94eb10d5c60a8016a7 Author: Jiayu Liu <jimex...@users.noreply.github.com> AuthorDate: Tue Jan 18 09:52:13 2022 +0800 remove update and merge (#1582) --- README.md | 2 +- datafusion-examples/examples/simple_udaf.rs | 58 ++++++++++++++++------ datafusion/src/datasource/file_format/parquet.rs | 48 ++++++++++++++---- datafusion/src/datasource/mod.rs | 4 +- .../src/physical_plan/distinct_expressions.rs | 50 ++++++++++++++++--- .../physical_plan/expressions/approx_distinct.rs | 17 ------- .../src/physical_plan/expressions/array_agg.rs | 43 ++++++++++++---- .../src/physical_plan/expressions/average.rs | 8 --- .../src/physical_plan/expressions/correlation.rs | 8 --- datafusion/src/physical_plan/expressions/count.rs | 8 --- .../src/physical_plan/expressions/covariance.rs | 8 --- .../src/physical_plan/expressions/min_max.rs | 20 -------- datafusion/src/physical_plan/expressions/stddev.rs | 8 --- datafusion/src/physical_plan/expressions/sum.rs | 8 --- .../src/physical_plan/expressions/variance.rs | 8 --- datafusion/src/physical_plan/mod.rs | 51 ++----------------- 16 files changed, 164 insertions(+), 185 deletions(-) diff --git a/README.md b/README.md index 0c7c76f..5e32bf7 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ the convenience of an SQL interface or a DataFrame API. ## Known Uses -Projects that adapt to or service as plugins to DataFusion: +Projects that adapt to or serve as plugins to DataFusion: - [datafusion-python](https://github.com/datafusion-contrib/datafusion-python) - [datafusion-java](https://github.com/datafusion-contrib/datafusion-java) diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 08706a3..3acace2 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -18,7 +18,7 @@ /// In this example we will declare a single-type, single return type UDAF that computes the geometric mean. /// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean use datafusion::arrow::{ - array::Float32Array, array::Float64Array, datatypes::DataType, + array::ArrayRef, array::Float32Array, array::Float64Array, datatypes::DataType, record_batch::RecordBatch, }; @@ -66,20 +66,6 @@ impl GeometricMean { pub fn new() -> Self { GeometricMean { n: 0, prod: 1.0 } } -} - -// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions -// to use them. -impl Accumulator for GeometricMean { - // this function serializes our state to `ScalarValue`, which DataFusion uses - // to pass this state between execution stages. - // Note that this can be arbitrary data. - fn state(&self) -> Result<Vec<ScalarValue>> { - Ok(vec![ - ScalarValue::from(self.prod), - ScalarValue::from(self.n), - ]) - } // this function receives one entry per argument of this accumulator. // DataFusion calls this function on every row, and expects this function to update the accumulator's state. @@ -114,6 +100,20 @@ impl Accumulator for GeometricMean { }; Ok(()) } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // This function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result<Vec<ScalarValue>> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } // DataFusion expects this function to return the final value of this aggregator. // in this case, this is the formula of the geometric mean @@ -122,9 +122,37 @@ impl Accumulator for GeometricMean { Ok(ScalarValue::from(value)) } + // DataFusion calls this function to update the accumulator's state for a batch + // of inputs rows. In this case the product is updated with values from the first column + // and the count is updated based on the row count + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + }; + (0..values[0].len()).try_for_each(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::<Result<Vec<_>>>()?; + self.update(&v) + }) + } + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, // that can be used to perform these operations on arrays instead of single values. // By default, these methods call `update` and `merge` row by row + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + (0..states[0].len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::<Result<Vec<_>>>()?; + self.merge(&v) + }) + } } #[tokio::main] diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index a947518..eedb4c9 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -36,6 +36,9 @@ use parquet::file::statistics::Statistics as ParquetStatistics; use super::FileFormat; use super::FileScanConfig; +use crate::arrow::array::{ + BooleanArray, Float32Array, Float64Array, Int32Array, Int64Array, +}; use crate::arrow::datatypes::{DataType, Field}; use crate::datasource::object_store::{ObjectReader, ObjectReaderStream}; use crate::datasource::{create_max_min_accs, get_col_stats}; @@ -47,7 +50,6 @@ use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::file_format::ParquetExec; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::{Accumulator, Statistics}; -use crate::scalar::ScalarValue; /// The default file exetension of parquet files pub const DEFAULT_PARQUET_EXTENSION: &str = ".parquet"; @@ -132,7 +134,9 @@ fn summarize_min_max( if let DataType::Boolean = fields[i].data_type() { if s.has_min_max_set() { if let Some(max_value) = &mut max_values[i] { - match max_value.update(&[ScalarValue::Boolean(Some(*s.max()))]) { + match max_value.update_batch(&[Arc::new(BooleanArray::from( + vec![Some(*s.max())], + ))]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -140,7 +144,9 @@ fn summarize_min_max( } } if let Some(min_value) = &mut min_values[i] { - match min_value.update(&[ScalarValue::Boolean(Some(*s.min()))]) { + match min_value.update_batch(&[Arc::new(BooleanArray::from( + vec![Some(*s.min())], + ))]) { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -154,7 +160,10 @@ fn summarize_min_max( if let DataType::Int32 = fields[i].data_type() { if s.has_min_max_set() { if let Some(max_value) = &mut max_values[i] { - match max_value.update(&[ScalarValue::Int32(Some(*s.max()))]) { + match max_value.update_batch(&[Arc::new(Int32Array::from_value( + *s.max(), + 1, + ))]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -162,7 +171,10 @@ fn summarize_min_max( } } if let Some(min_value) = &mut min_values[i] { - match min_value.update(&[ScalarValue::Int32(Some(*s.min()))]) { + match min_value.update_batch(&[Arc::new(Int32Array::from_value( + *s.min(), + 1, + ))]) { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -176,7 +188,10 @@ fn summarize_min_max( if let DataType::Int64 = fields[i].data_type() { if s.has_min_max_set() { if let Some(max_value) = &mut max_values[i] { - match max_value.update(&[ScalarValue::Int64(Some(*s.max()))]) { + match max_value.update_batch(&[Arc::new(Int64Array::from_value( + *s.max(), + 1, + ))]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -184,7 +199,10 @@ fn summarize_min_max( } } if let Some(min_value) = &mut min_values[i] { - match min_value.update(&[ScalarValue::Int64(Some(*s.min()))]) { + match min_value.update_batch(&[Arc::new(Int64Array::from_value( + *s.min(), + 1, + ))]) { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -198,7 +216,9 @@ fn summarize_min_max( if let DataType::Float32 = fields[i].data_type() { if s.has_min_max_set() { if let Some(max_value) = &mut max_values[i] { - match max_value.update(&[ScalarValue::Float32(Some(*s.max()))]) { + match max_value.update_batch(&[Arc::new(Float32Array::from( + vec![Some(*s.max())], + ))]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -206,7 +226,9 @@ fn summarize_min_max( } } if let Some(min_value) = &mut min_values[i] { - match min_value.update(&[ScalarValue::Float32(Some(*s.min()))]) { + match min_value.update_batch(&[Arc::new(Float32Array::from( + vec![Some(*s.min())], + ))]) { Ok(_) => {} Err(_) => { min_values[i] = None; @@ -220,7 +242,9 @@ fn summarize_min_max( if let DataType::Float64 = fields[i].data_type() { if s.has_min_max_set() { if let Some(max_value) = &mut max_values[i] { - match max_value.update(&[ScalarValue::Float64(Some(*s.max()))]) { + match max_value.update_batch(&[Arc::new(Float64Array::from( + vec![Some(*s.max())], + ))]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -228,7 +252,9 @@ fn summarize_min_max( } } if let Some(min_value) = &mut min_values[i] { - match min_value.update(&[ScalarValue::Float64(Some(*s.min()))]) { + match min_value.update_batch(&[Arc::new(Float64Array::from( + vec![Some(*s.min())], + ))]) { Ok(_) => {} Err(_) => { min_values[i] = None; diff --git a/datafusion/src/datasource/mod.rs b/datafusion/src/datasource/mod.rs index 6e119f0..33512b4 100644 --- a/datafusion/src/datasource/mod.rs +++ b/datafusion/src/datasource/mod.rs @@ -71,7 +71,7 @@ pub async fn get_statistics_with_limit( if let Some(max_value) = &mut max_values[i] { if let Some(file_max) = cs.max_value.clone() { - match max_value.update(&[file_max]) { + match max_value.update_batch(&[file_max.to_array()]) { Ok(_) => {} Err(_) => { max_values[i] = None; @@ -82,7 +82,7 @@ pub async fn get_statistics_with_limit( if let Some(min_value) = &mut min_values[i] { if let Some(file_min) = cs.min_value.clone() { - match min_value.update(&[file_min]) { + match min_value.update_batch(&[file_min.to_array()]) { Ok(_) => {} Err(_) => { min_values[i] = None; diff --git a/datafusion/src/physical_plan/distinct_expressions.rs b/datafusion/src/physical_plan/distinct_expressions.rs index c127042..080308a 100644 --- a/datafusion/src/physical_plan/distinct_expressions.rs +++ b/datafusion/src/physical_plan/distinct_expressions.rs @@ -17,13 +17,13 @@ //! Implementations for DISTINCT expressions, e.g. `COUNT(DISTINCT c)` +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field}; use std::any::Any; use std::fmt::Debug; use std::hash::Hash; use std::sync::Arc; -use arrow::datatypes::{DataType, Field}; - use ahash::RandomState; use std::collections::HashSet; @@ -130,8 +130,7 @@ struct DistinctCountAccumulator { state_data_types: Vec<DataType>, count_data_type: DataType, } - -impl Accumulator for DistinctCountAccumulator { +impl DistinctCountAccumulator { fn update(&mut self, values: &[ScalarValue]) -> Result<()> { // If a row has a NULL, it is not included in the final count. if !values.iter().any(|v| v.is_null()) { @@ -165,7 +164,33 @@ impl Accumulator for DistinctCountAccumulator { self.update(&row_values) }) } +} +impl Accumulator for DistinctCountAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + }; + (0..values[0].len()).try_for_each(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::<Result<Vec<_>>>()?; + self.update(&v) + }) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + (0..states[0].len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::<Result<Vec<_>>>()?; + self.merge(&v) + }) + } fn state(&self) -> Result<Vec<ScalarValue>> { let mut cols_out = self .state_data_types @@ -317,9 +342,20 @@ mod tests { let mut accum = agg.create_accumulator()?; - for row in rows.iter() { - accum.update(row)? - } + let cols = (0..rows[0].len()) + .map(|i| { + rows.iter() + .map(|inner| inner[i].clone()) + .collect::<Vec<ScalarValue>>() + }) + .collect::<Vec<_>>(); + + let arrays: Vec<ArrayRef> = cols + .iter() + .map(|c| ScalarValue::iter_to_array(c.clone())) + .collect::<Result<Vec<ArrayRef>>>()?; + + accum.update_batch(&arrays)?; Ok((accum.state()?, accum.evaluate()?)) } diff --git a/datafusion/src/physical_plan/expressions/approx_distinct.rs b/datafusion/src/physical_plan/expressions/approx_distinct.rs index ac7dcb3..9900780 100644 --- a/datafusion/src/physical_plan/expressions/approx_distinct.rs +++ b/datafusion/src/physical_plan/expressions/approx_distinct.rs @@ -217,23 +217,6 @@ impl<T: Hash> TryFrom<&ScalarValue> for HyperLogLog<T> { macro_rules! default_accumulator_impl { () => { - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - self.update_batch( - values - .iter() - .map(|s| s.to_array() as ArrayRef) - .collect::<Vec<_>>() - .as_slice(), - ) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - assert_eq!(1, states.len(), "expect only 1 element in the states"); - let other = HyperLogLog::try_from(&states[0])?; - self.hll.merge(&other); - Ok(()) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { assert_eq!(1, states.len(), "expect only 1 element in the states"); let binary_array = states[0].as_any().downcast_ref::<BinaryArray>().unwrap(); diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs index f912615..c237cc0 100644 --- a/datafusion/src/physical_plan/expressions/array_agg.rs +++ b/datafusion/src/physical_plan/expressions/array_agg.rs @@ -21,6 +21,7 @@ use super::format_state_name; use crate::error::Result; use crate::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; use crate::scalar::ScalarValue; +use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Field}; use std::any::Any; use std::sync::Arc; @@ -106,15 +107,6 @@ impl ArrayAggAccumulator { datatype: datatype.clone(), }) } -} - -impl Accumulator for ArrayAggAccumulator { - fn state(&self) -> Result<Vec<ScalarValue>> { - Ok(vec![ScalarValue::List( - Some(Box::new(self.array.clone())), - Box::new(self.datatype.clone()), - )]) - } fn update(&mut self, values: &[ScalarValue]) -> Result<()> { let value = &values[0]; @@ -137,6 +129,39 @@ impl Accumulator for ArrayAggAccumulator { } Ok(()) } +} + +impl Accumulator for ArrayAggAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + }; + (0..values[0].len()).try_for_each(|index| { + let v = values + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::<Result<Vec<_>>>()?; + self.update(&v) + }) + } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + }; + (0..states[0].len()).try_for_each(|index| { + let v = states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::<Result<Vec<_>>>()?; + self.merge(&v) + }) + } + fn state(&self) -> Result<Vec<ScalarValue>> { + Ok(vec![ScalarValue::List( + Some(Box::new(self.array.clone())), + Box::new(self.datatype.clone()), + )]) + } fn evaluate(&self) -> Result<ScalarValue> { Ok(ScalarValue::List( diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index e05a9ad..2b9fa9d 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -171,10 +171,6 @@ impl Accumulator for AvgAccumulator { Ok(vec![ScalarValue::from(self.count), self.sum.clone()]) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - unimplemented!("update_batch is implemented instead"); - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; @@ -183,10 +179,6 @@ impl Accumulator for AvgAccumulator { Ok(()) } - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - unimplemented!("merge_batch is implemented instead"); - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap(); // counts are summed diff --git a/datafusion/src/physical_plan/expressions/correlation.rs b/datafusion/src/physical_plan/expressions/correlation.rs index e91712e..598e031 100644 --- a/datafusion/src/physical_plan/expressions/correlation.rs +++ b/datafusion/src/physical_plan/expressions/correlation.rs @@ -204,14 +204,6 @@ impl Accumulator for CorrelationAccumulator { Ok(()) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - unimplemented!("update_batch is implemented instead"); - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - unimplemented!("merge_batch is implemented instead"); - } - fn evaluate(&self) -> Result<ScalarValue> { let covar = self.covar.evaluate()?; let stddev1 = self.stddev1.evaluate()?; diff --git a/datafusion/src/physical_plan/expressions/count.rs b/datafusion/src/physical_plan/expressions/count.rs index 5420a7c..830cbf3 100644 --- a/datafusion/src/physical_plan/expressions/count.rs +++ b/datafusion/src/physical_plan/expressions/count.rs @@ -112,14 +112,6 @@ impl Accumulator for CountAccumulator { Ok(()) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - unimplemented!("update_batch is implemented instead"); - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - unimplemented!("merge_batch is implemented instead"); - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap(); let delta = &compute::sum(counts); diff --git a/datafusion/src/physical_plan/expressions/covariance.rs b/datafusion/src/physical_plan/expressions/covariance.rs index 3d1913c..f264062 100644 --- a/datafusion/src/physical_plan/expressions/covariance.rs +++ b/datafusion/src/physical_plan/expressions/covariance.rs @@ -357,14 +357,6 @@ impl Accumulator for CovarianceAccumulator { Ok(()) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - unimplemented!("update_batch is implemented instead"); - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - unimplemented!("merge_batch is implemented instead"); - } - fn evaluate(&self) -> Result<ScalarValue> { let count = match self.stats_type { StatsType::Population => self.count, diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 8793718..125f2cb 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -440,16 +440,6 @@ impl Accumulator for MaxAccumulator { Ok(()) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - self.max = max(&self.max, value)?; - Ok(()) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - self.update(states) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { self.update_batch(states) } @@ -543,12 +533,6 @@ impl Accumulator for MinAccumulator { Ok(vec![self.min.clone()]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - let value = &values[0]; - self.min = min(&self.min, value)?; - Ok(()) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; let delta = &min_batch(values)?; @@ -556,10 +540,6 @@ impl Accumulator for MinAccumulator { Ok(()) } - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - self.update(states) - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { self.update_batch(states) } diff --git a/datafusion/src/physical_plan/expressions/stddev.rs b/datafusion/src/physical_plan/expressions/stddev.rs index 4404c5b..b9af721 100644 --- a/datafusion/src/physical_plan/expressions/stddev.rs +++ b/datafusion/src/physical_plan/expressions/stddev.rs @@ -224,14 +224,6 @@ impl Accumulator for StddevAccumulator { ]) } - fn update(&mut self, values: &[ScalarValue]) -> Result<()> { - self.variance.update(values) - } - - fn merge(&mut self, states: &[ScalarValue]) -> Result<()> { - self.variance.merge(states) - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { self.variance.update_batch(values) } diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index 9fc5c4d..e8f4420 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -353,20 +353,12 @@ impl Accumulator for SumAccumulator { Ok(vec![self.sum.clone()]) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - unimplemented!("update_batch is implemented instead"); - } - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; self.sum = sum(&self.sum, &sum_batch(values)?)?; Ok(()) } - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - unimplemented!("merge_batch is implemented instead"); - } - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { // sum(sum1, sum2, sum3, ...) = sum1 + sum2 + sum3 + ... self.update_batch(states) diff --git a/datafusion/src/physical_plan/expressions/variance.rs b/datafusion/src/physical_plan/expressions/variance.rs index 64294ac..38ee3d7 100644 --- a/datafusion/src/physical_plan/expressions/variance.rs +++ b/datafusion/src/physical_plan/expressions/variance.rs @@ -302,14 +302,6 @@ impl Accumulator for VarianceAccumulator { Ok(()) } - fn update(&mut self, _values: &[ScalarValue]) -> Result<()> { - unimplemented!("update_batch is implemented instead"); - } - - fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> { - unimplemented!("merge_batch is implemented instead"); - } - fn evaluate(&self) -> Result<ScalarValue> { let count = match self.stats_type { StatsType::Population => self.count, diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index b019b52..be59968 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -566,9 +566,9 @@ pub trait WindowExpr: Send + Sync + Debug { /// generically accumulates values. /// /// An accumulator knows how to: -/// * update its state from inputs via `update` +/// * update its state from inputs via `update_batch` /// * convert its internal state to a vector of scalar values -/// * update its state from multiple accumulators' states via `merge` +/// * update its state from multiple accumulators' states via `merge_batch` /// * compute the final value from its internal state via `evaluate` pub trait Accumulator: Send + Sync + Debug { /// Returns the state of the accumulator at the end of the accumulation. @@ -576,54 +576,11 @@ pub trait Accumulator: Send + Sync + Debug { // of two values, sum and n. fn state(&self) -> Result<Vec<ScalarValue>>; - /// Updates the accumulator's state from a vector of scalars - /// (called by default implementation of [`update_batch`]). - /// - /// Note: this method is often the simplest to implement and is - /// backwards compatible to help to lower the barrier to entry for - /// new users to write `Accumulators` - /// - /// You should always implement `update_batch` instead of this - /// method for production aggregators or if you find yourself - /// wanting to use mathematical kernels for [`ScalarValue`] such as - /// `ScalarValue::add`, `ScalarValue::mul`, etc - fn update(&mut self, values: &[ScalarValue]) -> Result<()>; - /// updates the accumulator's state from a vector of arrays. - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - }; - (0..values[0].len()).try_for_each(|index| { - let v = values - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::<Result<Vec<_>>>()?; - self.update(&v) - }) - } - - /// Updates the accumulator's state from a vector of scalars. - /// (called by default implementation of [`merge`]). - /// - /// You should always implement `merge_batch` instead of this - /// method for production aggregators. Please see notes on - /// [`update`] for more detail and rationale. - fn merge(&mut self, states: &[ScalarValue]) -> Result<()>; + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>; /// updates the accumulator's state from a vector of states. - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - }; - (0..states[0].len()).try_for_each(|index| { - let v = states - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::<Result<Vec<_>>>()?; - self.merge(&v) - }) - } + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>; /// returns its value based on its current state. fn evaluate(&self) -> Result<ScalarValue>;