This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch cleanup_sum_accumulator in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit b96d75ded20d1d0ea64a9166b94f10f239d43195 Author: Daniël Heres <[email protected]> AuthorDate: Fri Apr 7 13:31:07 2023 +0200 Cleanup sum accumulator --- datafusion/physical-expr/src/aggregate/sum.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/sum.rs b/datafusion/physical-expr/src/aggregate/sum.rs index 2f92aa9393..6df6674c22 100644 --- a/datafusion/physical-expr/src/aggregate/sum.rs +++ b/datafusion/physical-expr/src/aggregate/sum.rs @@ -139,7 +139,6 @@ impl AggregateExpr for Sum { #[derive(Debug)] struct SumAccumulator { sum: ScalarValue, - count: u64, } impl SumAccumulator { @@ -147,7 +146,6 @@ impl SumAccumulator { pub fn try_new(data_type: &DataType) -> Result<Self> { Ok(Self { sum: ScalarValue::try_from(data_type)?, - count: 0, }) } } @@ -231,12 +229,15 @@ pub(crate) fn add_to_row( impl Accumulator for SumAccumulator { fn state(&self) -> Result<Vec<ScalarValue>> { - Ok(vec![self.sum.clone(), ScalarValue::from(self.count)]) + Ok(vec![self.sum.clone()]) } fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; +<<<<<<< Updated upstream self.count += (values.len() - values.data().null_count()) as u64; +======= +>>>>>>> Stashed changes let delta = sum_batch(values, &self.sum.get_datatype())?; self.sum = self.sum.add(&delta)?; Ok(()) @@ -244,7 +245,10 @@ impl Accumulator for SumAccumulator { fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let values = &values[0]; +<<<<<<< Updated upstream self.count -= (values.len() - values.data().null_count()) as u64; +======= +>>>>>>> Stashed changes let delta = sum_batch(values, &self.sum.get_datatype())?; self.sum = self.sum.sub(&delta)?; Ok(()) @@ -258,11 +262,7 @@ impl Accumulator for SumAccumulator { fn evaluate(&self) -> Result<ScalarValue> { // TODO: add the checker for overflow // For the decimal(precision,_) data type, the absolute of value must be less than 10^precision. 
- if self.count == 0 { - ScalarValue::try_from(&self.sum.get_datatype()) - } else { - Ok(self.sum.clone()) - } + Ok(self.sum.clone()) } fn size(&self) -> usize {
