metesynnada commented on code in PR #7358:
URL: https://github.com/apache/arrow-datafusion/pull/7358#discussion_r1301426487
##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -208,97 +217,164 @@ impl PartialEq<dyn Any> for Avg {
}
/// An accumulator to compute the average
-#[derive(Debug)]
+#[derive(Debug, Default)]
pub struct AvgAccumulator {
- // sum is used for null
- sum: ScalarValue,
- sum_data_type: DataType,
- return_data_type: DataType,
+ sum: Option<f64>,
count: u64,
+ cast_input: bool,
}
impl AvgAccumulator {
- /// Creates a new `AvgAccumulator`
- pub fn try_new(datatype: &DataType, return_data_type: &DataType) -> Result<Self> {
- Ok(Self {
- sum: ScalarValue::try_from(datatype)?,
- sum_data_type: datatype.clone(),
- return_data_type: return_data_type.clone(),
- count: 0,
+ /// Create a new [`AvgAccumulator`]
+ ///
+ /// If `cast_input` is `true` this will automatically cast input to `f64`
+ pub fn new(cast_input: bool) -> Self {
+ Self {
+ cast_input,
+ ..Default::default()
+ }
+ }
+
+ fn cast_input(&self, input: &ArrayRef) -> Result<ArrayRef> {
+ Ok(match self.cast_input {
+ true => cast(input, &DataType::Float64)?,
+ false => input.clone(),
})
}
}
impl Accumulator for AvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
- Ok(vec![ScalarValue::from(self.count), self.sum.clone()])
+ Ok(vec![
+ ScalarValue::from(self.count),
+ ScalarValue::Float64(self.sum),
+ ])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = &values[0];
+ let values = self.cast_input(&values[0])?;
+ let values = values.as_primitive::<Float64Type>();
self.count += (values.len() - values.null_count()) as u64;
- self.sum = self
- .sum
- .add(&sum::sum_batch(values, &self.sum_data_type)?)?;
+ if let Some(x) = sum(values) {
+ let v = self.sum.get_or_insert(0.);
+ *v += x;
+ }
Ok(())
}
fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
- let values = &values[0];
+ let values = self.cast_input(&values[0])?;
+ let values = values.as_primitive::<Float64Type>();
self.count -= (values.len() - values.null_count()) as u64;
- let delta = sum_batch(values, &self.sum.get_datatype())?;
- self.sum = self.sum.sub(&delta)?;
+ if let Some(x) = sum(values) {
+ self.sum = Some(self.sum.unwrap() - x);
Review Comment:
We will be looking into this.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org
For queries about this service, please contact Infrastructure at:
users@infra.apache.org