Dandandan commented on code in PR #6800:
URL: https://github.com/apache/arrow-datafusion/pull/6800#discussion_r1248737610
##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -383,6 +417,234 @@ impl RowAccumulator for AvgRowAccumulator {
}
}
+/// An accumulator to compute the average of PrimitiveArray<T>.
+/// Stores values as native types, and does overflow checking
+#[derive(Debug)]
+struct AvgGroupsAccumulator<T, F>
+where
+ T: ArrowNumericType + Send,
+ F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+ /// The type of the internal sum
+ sum_data_type: DataType,
+
+ /// The type of the returned sum
+ return_data_type: DataType,
+
+ /// Count per group (use u64 to make UInt64Array)
+ counts: Vec<u64>,
+
+ /// Sums per group, stored as the native type
+ sums: Vec<T::Native>,
+
+ /// Function that computes the average (value / count)
+ avg_fn: F,
+}
+
+impl<T, F> AvgGroupsAccumulator<T, F>
+where
+ T: ArrowNumericType + Send,
+ F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+ pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn:
F) -> Self {
+ debug!(
+ "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) -->
{return_data_type:?}",
+ std::any::type_name::<T>()
+ );
+ Self {
+ return_data_type: return_data_type.clone(),
+ sum_data_type: sum_data_type.clone(),
+ counts: vec![],
+ sums: vec![],
+ avg_fn,
+ }
+ }
+
+ /// Adds the values in `values` to self.sums
+ fn update_sums(
+ &mut self,
+ values: &PrimitiveArray<T>,
+ group_indicies: &[usize],
+ opt_filter: Option<&arrow_array::BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ self.sums
+ .resize_with(total_num_groups, || T::default_value());
+
+ // AAL TODO
+ // TODO combine the null mask from values and opt_filter
+ let valids = values.nulls();
+
+ // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
+ let data: &[T::Native] = values.values();
+
+ match valids {
+ // use all values in group_index
+ None => {
+ let iter = group_indicies.iter().zip(data.iter());
+ for (group_index, new_value) in iter {
+ let sum = &mut self.sums[*group_index];
+ *sum = sum.add_wrapping(*new_value);
+ }
+ }
+ //
+ Some(valids) => {
+ let group_indices_chunks = group_indicies.chunks_exact(64);
+ let data_chunks = data.chunks_exact(64);
+ let bit_chunks = valids.inner().bit_chunks();
+
+ let group_indices_remainder = group_indices_chunks.remainder();
+ let data_remainder = data_chunks.remainder();
+
+ group_indices_chunks
+ .zip(data_chunks)
+ .zip(bit_chunks.iter())
+ .for_each(|((group_index_chunk, data_chunk), mask)| {
+ // index_mask has value 1 << i in the loop
+ let mut index_mask = 1;
+
group_index_chunk.iter().zip(data_chunk.iter()).for_each(
+ |(group_index, new_value)| {
+ if (mask & index_mask) != 0 {
+ let sum = &mut self.sums[*group_index];
+ *sum = sum.add_wrapping(*new_value);
+ }
+ index_mask <<= 1;
+ },
+ )
+ });
+
+ let remainder_bits = bit_chunks.remainder_bits();
+ group_indices_remainder
+ .iter()
+ .zip(data_remainder.iter())
+ .enumerate()
+ .for_each(|(i, (group_index, new_value))| {
+ if remainder_bits & (1 << i) != 0 {
+ let sum = &mut self.sums[*group_index];
+ *sum = sum.add_wrapping(*new_value);
+ }
+ });
+ }
+ }
+ Ok(())
+ }
+}
+
+impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
+where
+ T: ArrowNumericType + Send,
+ F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indicies: &[usize],
+ opt_filter: Option<&arrow_array::BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ assert_eq!(values.len(), 1, "single argument to update_batch");
+ let values = values.get(0).unwrap().as_primitive::<T>();
+
+ // update counts (TOD account for opt_filter)
+ self.counts.resize(total_num_groups, 0);
+ group_indicies.iter().for_each(|&group_idx| {
+ self.counts[group_idx] += 1;
+ });
+
+ // update values
+ self.update_sums(values, group_indicies, opt_filter,
total_num_groups)?;
+
+ Ok(())
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indicies: &[usize],
+ opt_filter: Option<&arrow_array::BooleanArray>,
+ total_num_groups: usize,
+ ) -> Result<()> {
+ assert_eq!(values.len(), 2, "two arguments to merge_batch");
+ // first batch is counts, second is partial sums
+ let counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
+ let partial_sums = values.get(1).unwrap().as_primitive::<T>();
+
+ // update counts by summing the partial sums (TODO account for
opt_filter)
+ self.counts.resize(total_num_groups, 0);
+ group_indicies.iter().zip(counts.values().iter()).for_each(
+ |(&group_idx, &count)| {
+ self.counts[group_idx] += count;
+ },
+ );
+
+ // update values
+ self.update_sums(partial_sums, group_indicies, opt_filter,
total_num_groups)?;
+
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> Result<ArrayRef> {
+ let counts = std::mem::take(&mut self.counts);
+ let sums = std::mem::take(&mut self.sums);
+
+ let averages: Vec<T::Native> = sums
+ .into_iter()
+ .zip(counts.into_iter())
+ .map(|(sum, count)| (self.avg_fn)(sum, count))
+ .collect::<Result<Vec<_>>>()?;
Review Comment:
If we add a `where` clause to `AvgGroupsAccumulator` requiring
`NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>`, we can collect
the averages directly into a `PrimitiveArray` instead of collecting into a `Vec` first.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]