This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch hash_agg_spike
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 82c9205200c0049aae7bfa07815cfcad2e7ffc3c
Author: Andrew Lamb <[email protected]>
AuthorDate: Sat Jul 1 05:07:56 2023 -0400

    Update comments and simplify code
---
 .../core/src/physical_plan/aggregates/row_hash2.rs | 54 +++++++++-------------
 1 file changed, 21 insertions(+), 33 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs
index 3e9dbfe0cf..792fbb4032 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs
@@ -36,7 +36,7 @@ use crate::physical_plan::aggregates::{
     PhysicalGroupBy,
 };
 use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
-use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
+use crate::physical_plan::{aggregates, PhysicalExpr};
 use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
 use arrow::array::*;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
@@ -204,12 +204,11 @@ impl GroupedHashAggregateStream2 {
 
         let timer = baseline_metrics.elapsed_compute().timer();
 
-        let mut aggregate_exprs = vec![];
-        let mut aggregate_arguments = vec![];
+        let aggregate_exprs = agg.aggr_expr.clone();
 
-        // The arguments for each aggregate, one vec of expressions
-        // per aggregation.
-        let all_aggregate_expressions = aggregates::aggregate_expressions(
+        // arguments for each aggregate, one vec of expressions per
+        // aggregate
+        let aggregate_arguments = aggregates::aggregate_expressions(
             &agg.aggr_expr,
             &agg.mode,
             agg_group_by.expr.len(),
@@ -222,16 +221,11 @@ impl GroupedHashAggregateStream2 {
             }
         };
 
-        for (agg_expr, agg_args) in agg
-            .aggr_expr
+        // Instantiate the accumulators
+        let accumulators: Vec<_> = aggregate_exprs
             .iter()
-            .zip(all_aggregate_expressions.into_iter())
-        {
-            aggregate_exprs.push(agg_expr.clone());
-            aggregate_arguments.push(agg_args);
-        }
-
-        let accumulators = create_accumulators(aggregate_exprs)?;
+            .map(|agg_expr| agg_expr.create_groups_accumulator())
+            .collect::<Result<_>>()?;
 
         let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
         let row_converter = RowConverter::new(
@@ -273,18 +267,6 @@ impl GroupedHashAggregateStream2 {
     }
 }
 
-/// Crate a [`GroupsAccumulator`] for each of the aggregate_exprs to
-/// hold the aggregation state
-fn create_accumulators(
-    aggregate_exprs: Vec<Arc<dyn AggregateExpr>>,
-) -> Result<Vec<Box<dyn GroupsAccumulator>>> {
-    debug!("Creating accumulator for {aggregate_exprs:#?}");
-    aggregate_exprs
-        .into_iter()
-        .map(|agg_expr| agg_expr.create_groups_accumulator())
-        .collect()
-}
-
 impl Stream for GroupedHashAggregateStream2 {
     type Item = Result<RecordBatch>;
 
@@ -363,11 +345,13 @@ impl RecordBatchStream for GroupedHashAggregateStream2 {
 }
 
 impl GroupedHashAggregateStream2 {
-    /// Update self.aggr_state based on the group_by values (result of evalauting the group_by_expressions)
+    /// Calculates the group indices for each input row of
+    /// `group_values`.
     ///
     /// At the return of this function,
-    /// `self.aggr_state.current_group_indices` has the correct
-    /// group_index for each row in the group_values
+    /// [`Self::current_group_indicies`] has the same number of
+    /// entries as each array in `group_values` and holds the correct
+    /// group_index for that row.
     fn update_group_state(
         &mut self,
         group_values: &[ArrayRef],
@@ -376,6 +360,7 @@ impl GroupedHashAggregateStream2 {
         // Convert the group keys into the row format
         let group_rows = self.row_converter.convert_columns(group_values)?;
         let n_rows = group_rows.num_rows();
+
         // 1.1 construct the key from the group values
         // 1.2 construct the mapping key if it does not exist
 
@@ -426,9 +411,8 @@ impl GroupedHashAggregateStream2 {
     ///
     /// If successful, returns the additional amount of memory, in
     /// bytes, that were allocated during this process.
-    ///
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<usize> {
-        // Evaluate the grouping expressions:
+        // Evaluate the grouping expressions
         let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
 
         // Keep track of memory allocated:
@@ -436,10 +420,12 @@ impl GroupedHashAggregateStream2 {
 
         // Evaluate the aggregation expressions.
         let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
-        // Evalaute the filter expressions, if any, against the inputs
+
+        // Evaluate the filter expressions, if any, against the inputs
         let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
 
         let row_converter_size_pre = self.row_converter.size();
+
         for group_values in &group_by_values {
             // calculate the group indicies for each input row
             self.update_group_state(group_values, &mut allocated)?;
@@ -458,6 +444,8 @@ impl GroupedHashAggregateStream2 {
                 let acc_size_pre = acc.size();
                 let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
 
+                // Call the appropriate method on each aggregator with
+                // the entire input row and the relevant group indexes
                 match self.mode {
                     AggregateMode::Partial | AggregateMode::Single => {
                         acc.update_batch(

Reply via email to