Dandandan commented on code in PR #7358:
URL: https://github.com/apache/arrow-datafusion/pull/7358#discussion_r1301400564


##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -208,97 +217,164 @@ impl PartialEq<dyn Any> for Avg {
 }
 
 /// An accumulator to compute the average
-#[derive(Debug)]
+#[derive(Debug, Default)]
 pub struct AvgAccumulator {
-    // sum is used for null
-    sum: ScalarValue,
-    sum_data_type: DataType,
-    return_data_type: DataType,
+    sum: Option<f64>,
     count: u64,
+    cast_input: bool,
 }
 
 impl AvgAccumulator {
-    /// Creates a new `AvgAccumulator`
-    pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> 
Result<Self> {
-        Ok(Self {
-            sum: ScalarValue::try_from(datatype)?,
-            sum_data_type: datatype.clone(),
-            return_data_type: return_data_type.clone(),
-            count: 0,
+    /// Create a new [`AvgAccumulator`]
+    ///
+    /// If `cast_input` is `true` this will automatically cast input to `f64`
+    pub fn new(cast_input: bool) -> Self {
+        Self {
+            cast_input,
+            ..Default::default()
+        }
+    }
+
+    fn cast_input(&self, input: &ArrayRef) -> Result<ArrayRef> {
+        Ok(match self.cast_input {
+            true => cast(input, &DataType::Float64)?,
+            false => input.clone(),
         })
     }
 }
 
 impl Accumulator for AvgAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
+        Ok(vec![
+            ScalarValue::from(self.count),
+            ScalarValue::Float64(self.sum),
+        ])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
+        let values = self.cast_input(&values[0])?;
+        let values = values.as_primitive::<Float64Type>();
 
         self.count += (values.len() - values.null_count()) as u64;
-        self.sum = self
-            .sum
-            .add(&sum::sum_batch(values, &self.sum_data_type)?)?;
+        if let Some(x) = sum(values) {
+            let v = self.sum.get_or_insert(0.);
+            *v += x;
+        }
         Ok(())
     }
 
     fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        let values = &values[0];
+        let values = self.cast_input(&values[0])?;
+        let values = values.as_primitive::<Float64Type>();
         self.count -= (values.len() - values.null_count()) as u64;
-        let delta = sum_batch(values, &self.sum.get_datatype())?;
-        self.sum = self.sum.sub(&delta)?;
+        if let Some(x) = sum(values) {
+            self.sum = Some(self.sum.unwrap() - x);

Review Comment:
   I think this is expected for floats: subtracting a retracted value does not exactly undo its earlier addition, and avoiding that would require keeping the intermediate values. Switching to decimal should give precise values.
   
   FYI @ozankabak @metesynnada 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to