alamb commented on code in PR #6800:
URL: https://github.com/apache/arrow-datafusion/pull/6800#discussion_r1248102535


##########
datafusion/core/src/physical_plan/aggregates/row_hash2.rs:
##########
@@ -0,0 +1,449 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Hash aggregation through row format
+//!
+//! POC demonstration of GroupByHashApproach
+
+use datafusion_physical_expr::GroupsAccumulator;
+use log::debug;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::vec;
+
+use ahash::RandomState;
+use arrow::row::{OwnedRow, RowConverter, SortField};
+use datafusion_physical_expr::hash_utils::create_hashes;
+use futures::ready;
+use futures::stream::{Stream, StreamExt};
+
+use crate::physical_plan::aggregates::{
+    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, 
AggregateMode,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use arrow::array::*;
+use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::TaskContext;
+use hashbrown::raw::RawTable;
+
+#[derive(Debug, Clone)]
+/// This object tracks the aggregation phase (input/output)
+pub(crate) enum ExecutionState {
+    ReadingInput,
+    /// When producing output, the remaining rows to output are stored
+    /// here and are sliced off as needed in batch_size chunks
+    ProducingOutput(RecordBatch),
+    Done,
+}
+
+use super::AggregateExec;
+
+/// Grouping aggregate

Review Comment:
   This code follows the basic structure of `row_hash` but the aggregate state 
management is different



##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -383,6 +417,234 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 }
 
+/// An accumulator to compute the average of PrimitiveArray<T>.
+/// Stores values as native types, and does overflow checking
+#[derive(Debug)]
+struct AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    /// The type of the internal sum
+    sum_data_type: DataType,
+
+    /// The type of the returned sum
+    return_data_type: DataType,
+
+    /// Count per group (use u64 to make UInt64Array)
+    counts: Vec<u64>,
+
+    /// Sums per group, stored as the native type
+    sums: Vec<T::Native>,
+
+    /// Function that computes the average (value / count)
+    avg_fn: F,
+}
+
+impl<T, F> AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: 
F) -> Self {
+        debug!(
+            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> 
{return_data_type:?}",
+            std::any::type_name::<T>()
+        );
+        Self {
+            return_data_type: return_data_type.clone(),
+            sum_data_type: sum_data_type.clone(),
+            counts: vec![],
+            sums: vec![],
+            avg_fn,
+        }
+    }
+
+    /// Adds the values in `values` to self.sums
+    fn update_sums(
+        &mut self,
+        values: &PrimitiveArray<T>,
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.sums
+            .resize_with(total_num_groups, || T::default_value());
+
+        // AAL TODO
+        // TODO combine the null mask from values and opt_filter
+        let valids = values.nulls();
+
+        // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum

Review Comment:
   This particular code is likely to be *very* common across most accumulators 
so I would hope to find some way to generalize it into its own function / macro



##########
datafusion/core/src/physical_plan/aggregates/row_hash2.rs:
##########
@@ -0,0 +1,449 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Hash aggregation through row format
+//!
+//! POC demonstration of GroupByHashApproach
+
+use datafusion_physical_expr::GroupsAccumulator;
+use log::debug;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::vec;
+
+use ahash::RandomState;
+use arrow::row::{OwnedRow, RowConverter, SortField};
+use datafusion_physical_expr::hash_utils::create_hashes;
+use futures::ready;
+use futures::stream::{Stream, StreamExt};
+
+use crate::physical_plan::aggregates::{
+    evaluate_group_by, evaluate_many, evaluate_optional, group_schema, 
AggregateMode,
+    PhysicalGroupBy,
+};
+use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
+use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
+use arrow::array::*;
+use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
+use datafusion_common::Result;
+use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::TaskContext;
+use hashbrown::raw::RawTable;
+
+#[derive(Debug, Clone)]
+/// This object tracks the aggregation phase (input/output)
+pub(crate) enum ExecutionState {
+    ReadingInput,
+    /// When producing output, the remaining rows to output are stored
+    /// here and are sliced off as needed in batch_size chunks
+    ProducingOutput(RecordBatch),
+    Done,
+}
+
+use super::AggregateExec;
+
+/// Grouping aggregate
+///
+/// For each aggregation entry, we use:
+/// - [Arrow-row] represents grouping keys for fast hash computation and 
comparison directly on raw bytes.
+/// - [GroupsAccumulator] to store per group aggregates
+///
+/// The architecture is the following:
+///
+/// TODO
+///
+/// [WordAligned]: datafusion_row::layout
+pub(crate) struct GroupedHashAggregateStream2 {
+    schema: SchemaRef,
+    input: SendableRecordBatchStream,
+    mode: AggregateMode,
+
+    /// Accumulators, one for each `AggregateExpr` in the query
+    accumulators: Vec<Box<dyn GroupsAccumulator>>,
+    /// Arguments expressionf or each accumulator
+    aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+    /// Filter expression to evaluate for each aggregate
+    filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
+
+    /// Converter for each row
+    row_converter: RowConverter,
+    group_by: PhysicalGroupBy,
+
+    /// The memory reservation for this grouping
+    reservation: MemoryReservation,
+
+    /// Logically maps group values to a group_index `group_states`
+    ///
+    /// Uses the raw API of hashbrown to avoid actually storing the
+    /// keys in the table
+    ///
+    /// keys: u64 hashes of the GroupValue
+    /// values: (hash, index into `group_states`)
+    map: RawTable<(u64, usize)>,
+
+    /// The actual group by values, stored in arrow Row format
+    /// the index of group_by_values is the index
+    /// https://github.com/apache/arrow-rs/issues/4466
+    group_by_values: Vec<OwnedRow>,
+
+    /// scratch space for the current Batch / Aggregate being
+    /// processed. Saved here to avoid reallocations
+    current_group_indices: Vec<usize>,
+
+    /// generating input/output?
+    exec_state: ExecutionState,
+
+    baseline_metrics: BaselineMetrics,
+
+    random_state: RandomState,
+    /// size to be used for resulting RecordBatches
+    batch_size: usize,
+}
+
+impl GroupedHashAggregateStream2 {
+    /// Create a new GroupedHashAggregateStream
+    pub fn new(
+        agg: &AggregateExec,
+        context: Arc<TaskContext>,
+        partition: usize,
+    ) -> Result<Self> {
+        debug!("Creating GroupedHashAggregateStream2");
+        let agg_schema = Arc::clone(&agg.schema);
+        let agg_group_by = agg.group_by.clone();
+        let agg_filter_expr = agg.filter_expr.clone();
+
+        let batch_size = context.session_config().batch_size();
+        let input = agg.input.execute(partition, Arc::clone(&context))?;
+        let baseline_metrics = BaselineMetrics::new(&agg.metrics, partition);
+
+        let timer = baseline_metrics.elapsed_compute().timer();
+
+        let mut aggregate_exprs = vec![];
+        let mut aggregate_arguments = vec![];
+
+        // The expressions to evaluate the batch, one vec of expressions per 
aggregation.
+        // Assuming create_schema() always puts group columns in front of 
aggregation columns, we set
+        // col_idx_base to the group expression count.
+
+        let all_aggregate_expressions = aggregates::aggregate_expressions(
+            &agg.aggr_expr,
+            &agg.mode,
+            agg_group_by.expr.len(),
+        )?;
+        let filter_expressions = match agg.mode {
+            AggregateMode::Partial | AggregateMode::Single => agg_filter_expr,
+            AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                vec![None; agg.aggr_expr.len()]
+            }
+        };
+
+        for (agg_expr, agg_args) in agg
+            .aggr_expr
+            .iter()
+            .zip(all_aggregate_expressions.into_iter())
+        {
+            aggregate_exprs.push(agg_expr.clone());
+            aggregate_arguments.push(agg_args);
+        }
+
+        let accumulators = create_accumulators(aggregate_exprs)?;
+
+        let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        let row_converter = RowConverter::new(
+            group_schema
+                .fields()
+                .iter()
+                .map(|f| SortField::new(f.data_type().clone()))
+                .collect(),
+        )?;
+
+        let name = format!("GroupedHashAggregateStream2[{partition}]");
+        let reservation = 
MemoryConsumer::new(name).register(context.memory_pool());
+        let map = RawTable::with_capacity(0);
+        let group_by_values = vec![];
+        let current_group_indices = vec![];
+
+        timer.done();
+
+        let exec_state = ExecutionState::ReadingInput;
+
+        Ok(GroupedHashAggregateStream2 {
+            schema: agg_schema,
+            input,
+            mode: agg.mode,
+            accumulators,
+            aggregate_arguments,
+            filter_expressions,
+            row_converter,
+            group_by: agg_group_by,
+            reservation,
+            map,
+            group_by_values,
+            current_group_indices,
+            exec_state,
+            baseline_metrics,
+            random_state: Default::default(),
+            batch_size,
+        })
+    }
+}
+
+/// Crate a `GroupsAccumulator` for each of the aggregate_exprs to hold the 
aggregation state
+fn create_accumulators(
+    aggregate_exprs: Vec<Arc<dyn AggregateExpr>>,
+) -> Result<Vec<Box<dyn GroupsAccumulator>>> {
+    debug!("Creating accumulator for {aggregate_exprs:#?}");
+    aggregate_exprs
+        .into_iter()
+        .map(|agg_expr| agg_expr.create_groups_accumulator())
+        .collect()
+}
+
+impl Stream for GroupedHashAggregateStream2 {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+
+        loop {
+            let exec_state = self.exec_state.clone();
+            match exec_state {
+                ExecutionState::ReadingInput => {
+                    match ready!(self.input.poll_next_unpin(cx)) {
+                        // new batch to aggregate
+                        Some(Ok(batch)) => {
+                            let timer = elapsed_compute.timer();
+                            let result = self.group_aggregate_batch(batch);
+                            timer.done();
+
+                            // allocate memory
+                            // This happens AFTER we actually used the memory, 
but simplifies the whole accounting and we are OK with
+                            // overshooting a bit. Also this means we either 
store the whole record batch or not.
+                            let result = result.and_then(|allocated| {
+                                self.reservation.try_grow(allocated)
+                            });
+
+                            if let Err(e) = result {
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                        }
+                        // inner had error, return to caller
+                        Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                        // inner is done, producing output
+                        None => {
+                            let timer = elapsed_compute.timer();
+                            match self.create_batch_from_map() {
+                                Ok(batch) => {
+                                    self.exec_state =
+                                        ExecutionState::ProducingOutput(batch)
+                                }
+                                Err(e) => return Poll::Ready(Some(Err(e))),
+                            }
+                            timer.done();
+                        }
+                    }
+                }
+
+                ExecutionState::ProducingOutput(batch) => {
+                    // slice off a part of the batch, if needed
+                    let output_batch = if batch.num_rows() <= self.batch_size {
+                        self.exec_state = ExecutionState::Done;
+                        batch
+                    } else {
+                        // output first batch_size rows
+                        let num_remaining = batch.num_rows() - self.batch_size;
+                        let remaining = batch.slice(self.batch_size, 
num_remaining);
+                        self.exec_state = 
ExecutionState::ProducingOutput(remaining);
+                        batch.slice(0, self.batch_size)
+                    };
+                    return Poll::Ready(Some(Ok(
+                        output_batch.record_output(&self.baseline_metrics)
+                    )));
+                }
+
+                ExecutionState::Done => return Poll::Ready(None),
+            }
+        }
+    }
+}
+
+impl RecordBatchStream for GroupedHashAggregateStream2 {
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+impl GroupedHashAggregateStream2 {
+    /// Update self.aggr_state based on the group_by values (result of 
evalauting the group_by_expressions)
+    ///
+    /// At the return of this function,
+    /// `self.aggr_state.current_group_indices` has the correct
+    /// group_index for each row in the group_values
+    fn update_group_state(
+        &mut self,
+        group_values: &[ArrayRef],
+        allocated: &mut usize,
+    ) -> Result<()> {
+        // Convert the group keys into the row format
+        let group_rows = self.row_converter.convert_columns(group_values)?;
+        let n_rows = group_rows.num_rows();
+        // 1.1 construct the key from the group values
+        // 1.2 construct the mapping key if it does not exist
+
+        // tracks to which group each of the input rows belongs
+        let group_indices = &mut self.current_group_indices;
+        group_indices.clear();
+
+        // 1.1 Calculate the group keys for the group values
+        let mut batch_hashes = vec![0; n_rows];
+        create_hashes(group_values, &self.random_state, &mut batch_hashes)?;
+
+        for (row, hash) in batch_hashes.into_iter().enumerate() {
+            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
+                // verify that a group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx  (aka group_values @ row)
+
+                // TODO update *allocated based on size of the row
+                // that was just pushed into
+                // aggr_state.group_by_values
+                group_rows.row(row) == self.group_by_values[*group_idx].row()
+            });
+
+            let group_idx = match entry {
+                // Existing group_index for this group value
+                Some((_hash, group_idx)) => *group_idx,
+                //  1.2 Need to create new entry for the group
+                None => {
+                    // Add new entry to aggr_state and save newly created index
+                    let group_idx = self.group_by_values.len();
+                    self.group_by_values.push(group_rows.row(row).owned());
+
+                    // for hasher function, use precomputed hash value
+                    self.map.insert_accounted(
+                        (hash, group_idx),
+                        |(hash, _group_index)| *hash,
+                        allocated,
+                    );
+                    group_idx
+                }
+            };
+            group_indices.push_accounted(group_idx, allocated);
+        }
+        Ok(())
+    }
+
+    /// Perform group-by aggregation for the given [`RecordBatch`].
+    ///
+    /// If successful, returns the additional amount of memory, in
+    /// bytes, that were allocated during this process.
+    ///
+    fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<usize> {
+        // Evaluate the grouping expressions:
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+
+        // Keep track of memory allocated:
+        let mut allocated = 0usize;
+
+        // Evaluate the aggregation expressions.
+        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+        // Evalaute the filter expressions, if any, against the inputs
+        let filter_values = evaluate_optional(&self.filter_expressions, 
&batch)?;
+
+        let row_converter_size_pre = self.row_converter.size();
+        for group_values in &group_by_values {
+            // calculate the group indicies for each input row
+            self.update_group_state(group_values, &mut allocated)?;
+            let group_indices = &self.current_group_indices;
+
+            // Gather the inputs to call the actual aggregation
+            let t = self
+                .accumulators
+                .iter_mut()
+                .zip(input_values.iter())
+                .zip(filter_values.iter());
+
+            let total_num_groups = self.group_by_values.len();
+
+            for ((acc, values), opt_filter) in t {
+                let acc_size_pre = acc.size();
+                let opt_filter = opt_filter.as_ref().map(|filter| 
filter.as_boolean());
+
+                match self.mode {
+                    AggregateMode::Partial | AggregateMode::Single => {
+                        acc.update_batch(

Review Comment:
   Here is one key difference -- each accumulator is called *once* per input 
batch (not once per group)



##########
datafusion/physical-expr/src/aggregate/utils.rs:
##########
@@ -37,45 +37,107 @@ pub fn get_accum_scalar_values_as_arrays(
         .collect::<Vec<_>>())
 }
 
-pub fn calculate_result_decimal_for_avg(
-    lit_value: i128,
-    count: i128,
-    scale: i8,
-    target_type: &DataType,
-) -> Result<ScalarValue> {
-    match target_type {
-        DataType::Decimal128(p, s) => {
-            // Different precision for decimal128 can store different range of 
value.
-            // For example, the precision is 3, the max of value is `999` and 
the min
-            // value is `-999`
-            let (target_mul, target_min, target_max) = (
-                10_i128.pow(*s as u32),
-                MIN_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1],
-                MAX_DECIMAL_FOR_EACH_PRECISION[*p as usize - 1],
-            );
-            let lit_scale_mul = 10_i128.pow(scale as u32);
-            if target_mul >= lit_scale_mul {
-                if let Some(value) = lit_value.checked_mul(target_mul / 
lit_scale_mul) {
-                    let new_value = value / count;
-                    if new_value >= target_min && new_value <= target_max {
-                        Ok(ScalarValue::Decimal128(Some(new_value), *p, *s))
-                    } else {
-                        Err(DataFusionError::Execution(
-                            "Arithmetic Overflow in 
AvgAccumulator".to_string(),
-                        ))
-                    }
-                } else {
-                    // can't convert the lit decimal to the returned data type
-                    Err(DataFusionError::Execution(
-                        "Arithmetic Overflow in AvgAccumulator".to_string(),
-                    ))
-                }
+/// Computes averages for `Decimal128` values, checking for overflow

Review Comment:
   This is part of https://github.com/apache/arrow-datafusion/pull/6810



##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -155,6 +160,35 @@ impl AggregateExpr for Avg {
             &self.rt_data_type,
         )?))
     }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        // instantiate specialized accumulator
+        match (&self.sum_data_type, &self.rt_data_type) {
+            (
+                DataType::Decimal128(_sum_precision, sum_scale),
+                DataType::Decimal128(target_precision, target_scale),
+            ) => {
+                let decimal_averager = Decimal128Averager::try_new(
+                    *sum_scale,
+                    *target_precision,
+                    *target_scale,
+                )?;
+
+                let avg_fn =

Review Comment:
   I am quite pleased that this code works for all the complexity of averages 
on decimals -- I feel that demonstrates the approach would work for other types 
and accumulators. 



##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -155,6 +160,35 @@ impl AggregateExpr for Avg {
             &self.rt_data_type,
         )?))
     }
+
+    fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
+        // instantiate specialized accumulator
+        match (&self.sum_data_type, &self.rt_data_type) {
+            (
+                DataType::Decimal128(_sum_precision, sum_scale),
+                DataType::Decimal128(target_precision, target_scale),
+            ) => {
+                let decimal_averager = Decimal128Averager::try_new(
+                    *sum_scale,
+                    *target_precision,
+                    *target_scale,
+                )?;
+
+                let avg_fn =
+                    move |sum: i128, count: u64| decimal_averager.avg(sum, 
count as i128);
+
+                Ok(Box::new(AvgGroupsAccumulator::<Decimal128Type, _>::new(

Review Comment:
   Here is a specialized accumulator -- it is instantiated once per native 
type (or other type we need to support in the accumulator), so each native 
type gets its own specialized implementation. 👨‍🍳 👌 



##########
datafusion/physical-expr/src/aggregate/groups_accumulator.rs:
##########
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Vectorized [`GroupsAccumulator`]
+
+use arrow_array::{ArrayRef, BooleanArray};
+use datafusion_common::Result;
+
+/// An implementation of GroupAccumulator is for a single aggregate

Review Comment:
   Here is the new `GroupsAccumulator` trait that all accumulators would have 
to implement.
   
   I would also plan to create a struct that implements this trait for 
aggregates based on `Accumulator`s
   
   ```rust
   struct GroupsAdapter {
     groups: Vec<Box<dyn Accumulator>>
   }
   
   impl GroupsAccumulator for GroupsAdapter {
   ...
   }
   ```
   
   So in that way we can start with simpler (but slower) `Accumulator` 
implementations for aggregates, and provide a fast GroupsAccumulator for the 
aggregates / types that need the specialization



##########
datafusion/physical-expr/src/aggregate/average.rs:
##########
@@ -383,6 +417,234 @@ impl RowAccumulator for AvgRowAccumulator {
     }
 }
 
+/// An accumulator to compute the average of PrimitiveArray<T>.
+/// Stores values as native types, and does overflow checking
+#[derive(Debug)]
+struct AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    /// The type of the internal sum
+    sum_data_type: DataType,
+
+    /// The type of the returned sum
+    return_data_type: DataType,
+
+    /// Count per group (use u64 to make UInt64Array)
+    counts: Vec<u64>,
+
+    /// Sums per group, stored as the native type
+    sums: Vec<T::Native>,
+
+    /// Function that computes the average (value / count)
+    avg_fn: F,
+}
+
+impl<T, F> AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: 
F) -> Self {
+        debug!(
+            "AvgGroupsAccumulator ({}, sum type: {sum_data_type:?}) --> 
{return_data_type:?}",
+            std::any::type_name::<T>()
+        );
+        Self {
+            return_data_type: return_data_type.clone(),
+            sum_data_type: sum_data_type.clone(),
+            counts: vec![],
+            sums: vec![],
+            avg_fn,
+        }
+    }
+
+    /// Adds the values in `values` to self.sums
+    fn update_sums(
+        &mut self,
+        values: &PrimitiveArray<T>,
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.sums
+            .resize_with(total_num_groups, || T::default_value());
+
+        // AAL TODO
+        // TODO combine the null mask from values and opt_filter
+        let valids = values.nulls();
+
+        // This is based on (ahem, COPY/PASTA) arrow::compute::aggregate::sum
+        let data: &[T::Native] = values.values();
+
+        match valids {
+            // use all values in group_index
+            None => {
+                let iter = group_indicies.iter().zip(data.iter());
+                for (group_index, new_value) in iter {
+                    let sum = &mut self.sums[*group_index];
+                    *sum = sum.add_wrapping(*new_value);
+                }
+            }
+            //
+            Some(valids) => {
+                let group_indices_chunks = group_indicies.chunks_exact(64);
+                let data_chunks = data.chunks_exact(64);
+                let bit_chunks = valids.inner().bit_chunks();
+
+                let group_indices_remainder = group_indices_chunks.remainder();
+                let data_remainder = data_chunks.remainder();
+
+                group_indices_chunks
+                    .zip(data_chunks)
+                    .zip(bit_chunks.iter())
+                    .for_each(|((group_index_chunk, data_chunk), mask)| {
+                        // index_mask has value 1 << i in the loop
+                        let mut index_mask = 1;
+                        
group_index_chunk.iter().zip(data_chunk.iter()).for_each(
+                            |(group_index, new_value)| {
+                                if (mask & index_mask) != 0 {
+                                    let sum = &mut self.sums[*group_index];
+                                    *sum = sum.add_wrapping(*new_value);
+                                }
+                                index_mask <<= 1;
+                            },
+                        )
+                    });
+
+                let remainder_bits = bit_chunks.remainder_bits();
+                group_indices_remainder
+                    .iter()
+                    .zip(data_remainder.iter())
+                    .enumerate()
+                    .for_each(|(i, (group_index, new_value))| {
+                        if remainder_bits & (1 << i) != 0 {
+                            let sum = &mut self.sums[*group_index];
+                            *sum = sum.add_wrapping(*new_value);
+                        }
+                    });
+            }
+        }
+        Ok(())
+    }
+}
+
+impl<T, F> GroupsAccumulator for AvgGroupsAccumulator<T, F>
+where
+    T: ArrowNumericType + Send,
+    F: Fn(T::Native, u64) -> Result<T::Native> + Send,
+{
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 1, "single argument to update_batch");
+        let values = values.get(0).unwrap().as_primitive::<T>();
+
+        // update counts (TOD account for opt_filter)
+        self.counts.resize(total_num_groups, 0);
+        group_indicies.iter().for_each(|&group_idx| {
+            self.counts[group_idx] += 1;
+        });
+
+        // update values
+        self.update_sums(values, group_indicies, opt_filter, 
total_num_groups)?;
+
+        Ok(())
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indicies: &[usize],
+        opt_filter: Option<&arrow_array::BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        assert_eq!(values.len(), 2, "two arguments to merge_batch");
+        // first batch is counts, second is partial sums
+        let counts = values.get(0).unwrap().as_primitive::<UInt64Type>();
+        let partial_sums = values.get(1).unwrap().as_primitive::<T>();
+
+        // update counts by summing the partial sums (TODO account for 
opt_filter)
+        self.counts.resize(total_num_groups, 0);
+        group_indicies.iter().zip(counts.values().iter()).for_each(
+            |(&group_idx, &count)| {
+                self.counts[group_idx] += count;
+            },
+        );
+
+        // update values
+        self.update_sums(partial_sums, group_indicies, opt_filter, 
total_num_groups)?;
+
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> Result<ArrayRef> {
+        let counts = std::mem::take(&mut self.counts);
+        let sums = std::mem::take(&mut self.sums);
+
+        let averages: Vec<T::Native> = sums
+            .into_iter()
+            .zip(counts.into_iter())
+            .map(|(sum, count)| (self.avg_fn)(sum, count))
+            .collect::<Result<Vec<_>>>()?;
+
+        // TODO figure out how to do this without the iter / copy

Review Comment:
   This probably copies values once (rather than reusing the Vec) -- maybe 
@tustvold can help me figure out the generic magic to make a `PrimitiveArray<T>` 
from a `Vec<T::Native>` without a copy



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to