This is an automated email from the ASF dual-hosted git repository. dheres pushed a commit to branch hash_agg_spike in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
commit 82c9205200c0049aae7bfa07815cfcad2e7ffc3c Author: Andrew Lamb <[email protected]> AuthorDate: Sat Jul 1 05:07:56 2023 -0400 Update comments and simplify code --- .../core/src/physical_plan/aggregates/row_hash2.rs | 54 +++++++++------------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs index 3e9dbfe0cf..792fbb4032 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash2.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash2.rs @@ -36,7 +36,7 @@ use crate::physical_plan::aggregates::{ PhysicalGroupBy, }; use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput}; -use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr}; +use crate::physical_plan::{aggregates, PhysicalExpr}; use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream}; use arrow::array::*; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; @@ -204,12 +204,11 @@ impl GroupedHashAggregateStream2 { let timer = baseline_metrics.elapsed_compute().timer(); - let mut aggregate_exprs = vec![]; - let mut aggregate_arguments = vec![]; + let aggregate_exprs = agg.aggr_expr.clone(); - // The arguments for each aggregate, one vec of expressions - // per aggregation. - let all_aggregate_expressions = aggregates::aggregate_expressions( + // arguments for each aggregate, one vec of expressions per + // aggregate + let aggregate_arguments = aggregates::aggregate_expressions( &agg.aggr_expr, &agg.mode, agg_group_by.expr.len(), @@ -222,16 +221,11 @@ impl GroupedHashAggregateStream2 { } }; - for (agg_expr, agg_args) in agg - .aggr_expr + // Instantiate the accumulators + let accumulators: Vec<_> = aggregate_exprs .iter() - .zip(all_aggregate_expressions.into_iter()) - { - aggregate_exprs.push(agg_expr.clone()); - aggregate_arguments.push(agg_args); - } - - let accumulators = create_accumulators(aggregate_exprs)?; + .map(|agg_expr| agg_expr.create_groups_accumulator()) + .collect::<Result<_>>()?; let group_schema = group_schema(&agg_schema, agg_group_by.expr.len()); let row_converter = RowConverter::new( @@ -273,18 +267,6 @@ impl GroupedHashAggregateStream2 { } } -/// Crate a [`GroupsAccumulator`] for each of the aggregate_exprs to -/// hold the aggregation state -fn create_accumulators( - aggregate_exprs: Vec<Arc<dyn AggregateExpr>>, -) -> Result<Vec<Box<dyn GroupsAccumulator>>> { - debug!("Creating accumulator for {aggregate_exprs:#?}"); - aggregate_exprs - .into_iter() - .map(|agg_expr| agg_expr.create_groups_accumulator()) - .collect() -} - impl Stream for GroupedHashAggregateStream2 { type Item = Result<RecordBatch>; @@ -363,11 +345,13 @@ impl RecordBatchStream for GroupedHashAggregateStream2 { } impl GroupedHashAggregateStream2 { - /// Update self.aggr_state based on the group_by values (result of evalauting the group_by_expressions) + /// Calculates the group indicies for each input row of + /// `group_values`. /// /// At the return of this function, - /// `self.aggr_state.current_group_indices` has the correct - /// group_index for each row in the group_values + /// [`Self::current_group_indicies`] has the same number of + /// entries as each array in `group_values` and holds the correct + /// group_index for that row. fn update_group_state( &mut self, group_values: &[ArrayRef], @@ -376,6 +360,7 @@ impl GroupedHashAggregateStream2 { // Convert the group keys into the row format let group_rows = self.row_converter.convert_columns(group_values)?; let n_rows = group_rows.num_rows(); + // 1.1 construct the key from the group values // 1.2 construct the mapping key if it does not exist @@ -426,9 +411,8 @@ impl GroupedHashAggregateStream2 { /// /// If successful, returns the additional amount of memory, in /// bytes, that were allocated during this process. - /// fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<usize> { - // Evaluate the grouping expressions: + // Evaluate the grouping expressions let group_by_values = evaluate_group_by(&self.group_by, &batch)?; // Keep track of memory allocated: @@ -436,10 +420,12 @@ impl GroupedHashAggregateStream2 { // Evaluate the aggregation expressions. let input_values = evaluate_many(&self.aggregate_arguments, &batch)?; - // Evalaute the filter expressions, if any, against the inputs + + // Evalute the filter expressions, if any, against the inputs let filter_values = evaluate_optional(&self.filter_expressions, &batch)?; let row_converter_size_pre = self.row_converter.size(); + for group_values in &group_by_values { // calculate the group indicies for each input row self.update_group_state(group_values, &mut allocated)?; @@ -458,6 +444,8 @@ impl GroupedHashAggregateStream2 { let acc_size_pre = acc.size(); let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean()); + // Call the appropriate method on each aggregator with + // the entire input row and the relevant group indexes match self.mode { AggregateMode::Partial | AggregateMode::Single => { acc.update_batch(
