This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 96678875c avoid group-key/-state clones for hash-groupby (#4651)
96678875c is described below

commit 96678875c8f310691c2c381a6f48d06aa8dbfccb
Author: Marco Neumann <[email protected]>
AuthorDate: Fri Dec 16 09:58:17 2022 +0000

    avoid group-key/-state clones for hash-groupby (#4651)
---
 .../core/src/physical_plan/aggregates/hash.rs      | 51 +++++++++++++++-------
 1 file changed, 35 insertions(+), 16 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index 0d35f5b0d..4d1933080 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -17,6 +17,7 @@
 
 //! Defines the execution plan for the hash aggregate operation
 
+use std::collections::VecDeque;
 use std::sync::Arc;
 use std::task::{Context, Poll};
 use std::vec;
@@ -91,7 +92,7 @@ struct GroupedHashAggregateStreamInner {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
     mode: AggregateMode,
-    accumulators: Accumulators,
+    accumulators: Option<Accumulators>,
     aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
 
     aggr_expr: Vec<Arc<dyn AggregateExpr>>,
@@ -133,7 +134,7 @@ impl GroupedHashAggregateStream {
             group_by,
             baseline_metrics,
             aggregate_expressions,
-            accumulators: Accumulators {
+            accumulators: Some(Accumulators {
                 memory_consumer: MemoryConsumerProxy::new(
                     "GroupBy Hash Accumulators",
                     MemoryConsumerId::new(partition),
@@ -141,7 +142,7 @@ impl GroupedHashAggregateStream {
                 ),
                 map: RawTable::with_capacity(0),
                 group_states: Vec::with_capacity(0),
-            },
+            }),
             random_state: Default::default(),
             finished: false,
         };
@@ -157,13 +158,15 @@ impl GroupedHashAggregateStream {
                 let result = match this.input.next().await {
                     Some(Ok(batch)) => {
                         let timer = elapsed_compute.timer();
+                        let accumulators =
+                            this.accumulators.as_mut().expect("not yet 
finished");
                         let result = group_aggregate_batch(
                             &this.mode,
                             &this.random_state,
                             &this.group_by,
                             &this.aggr_expr,
                             batch,
-                            &mut this.accumulators,
+                            accumulators,
                             &this.aggregate_expressions,
                         );
 
@@ -174,7 +177,7 @@ impl GroupedHashAggregateStream {
                         // overshooting a bit. Also this means we either store 
the whole record batch or not.
                         let result = match result {
                             Ok(allocated) => {
-                                
this.accumulators.memory_consumer.alloc(allocated).await
+                                
accumulators.memory_consumer.alloc(allocated).await
                             }
                             Err(e) => Err(e),
                         };
@@ -190,7 +193,8 @@ impl GroupedHashAggregateStream {
                         let timer = 
this.baseline_metrics.elapsed_compute().timer();
                         let result = create_batch_from_map(
                             &this.mode,
-                            &this.accumulators,
+                            std::mem::take(&mut this.accumulators)
+                                .expect("not yet finished"),
                             this.group_by.expr.len(),
                             &this.schema,
                         )
@@ -475,7 +479,7 @@ impl std::fmt::Debug for Accumulators {
 /// ```
 fn create_batch_from_map(
     mode: &AggregateMode,
-    accumulators: &Accumulators,
+    accumulators: Accumulators,
     num_group_expr: usize,
     output_schema: &Schema,
 ) -> ArrowResult<RecordBatch> {
@@ -498,14 +502,26 @@ fn create_batch_from_map(
         }
     }
 
+    // make group states mutable
+    let (mut group_by_values_vec, mut accumulator_set_vec): (Vec<_>, Vec<_>) =
+        accumulators
+            .group_states
+            .into_iter()
+            .map(|group_state| {
+                (
+                    VecDeque::from(group_state.group_by_values.to_vec()),
+                    VecDeque::from(group_state.accumulator_set),
+                )
+            })
+            .unzip();
+
     // First, output all group by exprs
     let mut columns = (0..num_group_expr)
-        .map(|i| {
+        .map(|_| {
             ScalarValue::iter_to_array(
-                accumulators
-                    .group_states
-                    .iter()
-                    .map(|group_state| group_state.group_by_values[i].clone()),
+                group_by_values_vec
+                    .iter_mut()
+                    .map(|x| x.pop_front().expect("invalid group_by_values")),
             )
         })
         .collect::<Result<Vec<_>>>()?;
@@ -516,8 +532,8 @@ fn create_batch_from_map(
             match mode {
                 AggregateMode::Partial => {
                     let res = ScalarValue::iter_to_array(
-                        accumulators.group_states.iter().map(|group_state| {
-                            group_state.accumulator_set[x]
+                        accumulator_set_vec.iter().map(|accumulator_set| {
+                            accumulator_set[x]
                                 .state()
                                 .map(|x| x[y].clone())
                                 .expect("unexpected accumulator state in hash 
aggregate")
@@ -528,8 +544,11 @@ fn create_batch_from_map(
                 }
                 AggregateMode::Final | AggregateMode::FinalPartitioned => {
                     let res = ScalarValue::iter_to_array(
-                        accumulators.group_states.iter().map(|group_state| {
-                            group_state.accumulator_set[x].evaluate().unwrap()
+                        accumulator_set_vec.iter_mut().map(|x| {
+                            x.pop_front()
+                                .expect("invalid accumulator_set")
+                                .evaluate()
+                                .unwrap()
                         }),
                     )?;
                     columns.push(res);

Reply via email to