alamb commented on code in PR #6800:
URL: https://github.com/apache/arrow-datafusion/pull/6800#discussion_r1249750870


##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -383,6 +417,234 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 }
 
+/// An accumulator to compute the average of PrimitiveArray<T>.
+/// Stores values as native types, and does overflow checking
+#[derive(Debug)]
+struct AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    /// The type of the internal sum
+    sum_data_type: DataType,
+
+    /// The type of the returned sum
+    return_data_type: DataType,
+
+    /// Count per group (use u64 to make UInt64Array)
+    counts: Vec<u64>,
+
+    /// Sums per group, stored as the native type
+    sums: Vec<T::Native>,
+
+    /// Function that computes the average (value / count)
+    avg_fn: F,
+}
+
+impl<T, F> AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self {
+        debug!(
+            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> {return_data_type:?}",
+            std::any::type_name::<T>()
+        );
+        Self {
+            return_data_type: return_data_type.clone(),
+            sum_data_type: sum_data_type.clone(),
+            counts: vec![],
+            sums: vec![],
+            avg_fn,
+        }
+    }
+
+    /// Adds the values in `values` to self.sums
+    fn update_sums(
+        &mut self,
+        values: &PrimitiveArray<T>,
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.sums
+            .resize_with(total_num_groups, || T::default_value());
+
+        // AAL TODO
+        // TODO combine the null mask from values and opt_filter
+        let valids = values.nulls();
+
+        // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
+        let data: &[T::Native] = values.values();
+
+        match valids {
+            // use all values in group_index
+            None => {
+                let iter = group_indicies.iter().zip(data.iter());
+                for (group_index, new_value) in iter {
+                    let sum = &mut self.sums[*group_index];
+                    *sum = sum.add_wrapping(*new_value);
+                }
+            }
+            //
+            Some(valids) => {
+                let group_indices_chunks = group_indicies.chunks_exact(64);
+                let data_chunks = data.chunks_exact(64);
+                let bit_chunks = valids.inner().bit_chunks();
+
+                let group_indices_remainder = group_indices_chunks.remainder();
+                let data_remainder = data_chunks.remainder();
+
+                group_indices_chunks
+                    .zip(data_chunks)
+                    .zip(bit_chunks.iter())
+                    .for_each(|((group_index_chunk, data_chunk), mask)| {
+                        // index_mask has value 1 << i in the loop
+                        let mut index_mask = 1;
+                        group_index_chunk.iter().zip(data_chunk.iter()).for_each(
+                            |(group_index, new_value)| {
+                                if (mask & index_mask) != 0 {
+                                    let sum = &mut self.sums[*group_index];
+                                    *sum = sum.add_wrapping(*new_value);
+                                }
+                                index_mask <<= 1;
+                            },
+                        )
+                    });
+
+                let remainder_bits = bit_chunks.remainder_bits();
+                group_indices_remainder
+                    .iter()
+                    .zip(data_remainder.iter())
+                    .enumerate()
+                    .for_each(|(i, (group_index, new_value))| {
+                        if remainder_bits & (1 << i) != 0 {
+                            let sum = &mut self.sums[*group_index];
+                            *sum = sum.add_wrapping(*new_value);
+                        }
+                    });
+            }
+        }
+        Ok(())
+    }
+}
+
+impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = values.get(0).unwrap().as_primitive::<T>();
+
+        // update counts (TODO account for opt_filter)
+        self.counts.resize(total_num_groups, 0);
+        group_indicies.iter().for_each(|&group_idx| {
+            self.counts[group_idx] += 1;
+        });
+
+        // update values
+        self.update_sums(values, group_indicies, opt_filter, total_num_groups)?;
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 2, "two arguments to merge_batch");
+        // first batch is counts, second is partial sums
+        let counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
+        let partial_sums = values.get(1).unwrap().as_primitive::<T>();
+
+        // update counts by summing the partial sums (TODO account for opt_filter)
+        self.counts.resize(total_num_groups, 0);
+        group_indicies.iter().zip(counts.values().iter()).for_each(
+            |(&group_idx, &count)| {
+                self.counts[group_idx] += count;
+            },
+        );
+
+        // update values
+        self.update_sums(partial_sums, group_indicies, opt_filter, total_num_groups)?;
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ArrayRef> {
+        let counts = std::mem::take(&mut self.counts);
+        let sums = std::mem::take(&mut self.sums);
+
+        let averages: Vec<T::Native> = sums
+            .into_iter()
+            .zip(counts.into_iter())
+            .map(|(sum, count)| (self.avg_fn)(sum, count))
+            .collect::<Result<Vec<_>>>()?;

Review Comment:
   It worked great:
   ```diff
   diff --git a/datafusion/physical-expr/src/aggregate/average.rs b/datafusion/physical-expr/src/aggregate/average.rs
   index ee249f3bd1..8de2460a17 100644
   --- a/datafusion/physical-expr/src/aggregate/average.rs
   +++ b/datafusion/physical-expr/src/aggregate/average.rs
   @@ -620,8 +620,9 @@ where
                .map(|(sum, count)| (self.avg_fn)(sum, count))
                .collect::<Result<Vec<_>>>()?;
    
   -        // TODO figure out how to do this without the iter / copy
   -        let array = PrimitiveArray::<T>::from_iter_values(averages);
   +        // Create a primitive array (without a copy)
   +        let nulls = None; // TODO implement null handling
   +        let array = PrimitiveArray::<T>::new(averages.into(), nulls);
    
            // fix up decimal precision and scale for decimals
            let array = adjust_output_array(&self.return_data_type, Arc::new(array))?;
   @@ -637,8 +638,8 @@ where
    
            let sums = std::mem::take(&mut self.sums);
            // create array from vec is zero copy
   -        // TODO figure out how to do this without the iter / copy
   -        let sums: PrimitiveArray<T> = PrimitiveArray::from_iter_values(sums);
   +        let nulls = None; // TODO implement null handling
   +        let sums = PrimitiveArray::<T>::new(sums.into(), nulls);
    
            // fix up decimal precision and scale for decimals
            let sums = adjust_output_array(&self.sum_data_type, Arc::new(sums))?;
   ```
   
   👍 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to