This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 96678875c avoid group-key/-state clones for hash-groupby (#4651)
96678875c is described below
commit 96678875c8f310691c2c381a6f48d06aa8dbfccb
Author: Marco Neumann <[email protected]>
AuthorDate: Fri Dec 16 09:58:17 2022 +0000
avoid group-key/-state clones for hash-groupby (#4651)
---
.../core/src/physical_plan/aggregates/hash.rs | 51 +++++++++++++++-------
1 file changed, 35 insertions(+), 16 deletions(-)
diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs
index 0d35f5b0d..4d1933080 100644
--- a/datafusion/core/src/physical_plan/aggregates/hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/hash.rs
@@ -17,6 +17,7 @@
//! Defines the execution plan for the hash aggregate operation
+use std::collections::VecDeque;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
@@ -91,7 +92,7 @@ struct GroupedHashAggregateStreamInner {
schema: SchemaRef,
input: SendableRecordBatchStream,
mode: AggregateMode,
- accumulators: Accumulators,
+ accumulators: Option<Accumulators>,
aggregate_expressions: Vec<Vec<Arc<dyn PhysicalExpr>>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
@@ -133,7 +134,7 @@ impl GroupedHashAggregateStream {
group_by,
baseline_metrics,
aggregate_expressions,
- accumulators: Accumulators {
+ accumulators: Some(Accumulators {
memory_consumer: MemoryConsumerProxy::new(
"GroupBy Hash Accumulators",
MemoryConsumerId::new(partition),
@@ -141,7 +142,7 @@ impl GroupedHashAggregateStream {
),
map: RawTable::with_capacity(0),
group_states: Vec::with_capacity(0),
- },
+ }),
random_state: Default::default(),
finished: false,
};
@@ -157,13 +158,15 @@ impl GroupedHashAggregateStream {
let result = match this.input.next().await {
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
+ let accumulators =
+ this.accumulators.as_mut().expect("not yet
finished");
let result = group_aggregate_batch(
&this.mode,
&this.random_state,
&this.group_by,
&this.aggr_expr,
batch,
- &mut this.accumulators,
+ accumulators,
&this.aggregate_expressions,
);
@@ -174,7 +177,7 @@ impl GroupedHashAggregateStream {
// overshooting a bit. Also this means we either store
the whole record batch or not.
let result = match result {
Ok(allocated) => {
-
this.accumulators.memory_consumer.alloc(allocated).await
+
accumulators.memory_consumer.alloc(allocated).await
}
Err(e) => Err(e),
};
@@ -190,7 +193,8 @@ impl GroupedHashAggregateStream {
let timer =
this.baseline_metrics.elapsed_compute().timer();
let result = create_batch_from_map(
&this.mode,
- &this.accumulators,
+ std::mem::take(&mut this.accumulators)
+ .expect("not yet finished"),
this.group_by.expr.len(),
&this.schema,
)
@@ -475,7 +479,7 @@ impl std::fmt::Debug for Accumulators {
/// ```
fn create_batch_from_map(
mode: &AggregateMode,
- accumulators: &Accumulators,
+ accumulators: Accumulators,
num_group_expr: usize,
output_schema: &Schema,
) -> ArrowResult<RecordBatch> {
@@ -498,14 +502,26 @@ fn create_batch_from_map(
}
}
+ // make group states mutable
+ let (mut group_by_values_vec, mut accumulator_set_vec): (Vec<_>, Vec<_>) =
+ accumulators
+ .group_states
+ .into_iter()
+ .map(|group_state| {
+ (
+ VecDeque::from(group_state.group_by_values.to_vec()),
+ VecDeque::from(group_state.accumulator_set),
+ )
+ })
+ .unzip();
+
// First, output all group by exprs
let mut columns = (0..num_group_expr)
- .map(|i| {
+ .map(|_| {
ScalarValue::iter_to_array(
- accumulators
- .group_states
- .iter()
- .map(|group_state| group_state.group_by_values[i].clone()),
+ group_by_values_vec
+ .iter_mut()
+ .map(|x| x.pop_front().expect("invalid group_by_values")),
)
})
.collect::<Result<Vec<_>>>()?;
@@ -516,8 +532,8 @@ fn create_batch_from_map(
match mode {
AggregateMode::Partial => {
let res = ScalarValue::iter_to_array(
- accumulators.group_states.iter().map(|group_state| {
- group_state.accumulator_set[x]
+ accumulator_set_vec.iter().map(|accumulator_set| {
+ accumulator_set[x]
.state()
.map(|x| x[y].clone())
.expect("unexpected accumulator state in hash
aggregate")
@@ -528,8 +544,11 @@ fn create_batch_from_map(
}
AggregateMode::Final | AggregateMode::FinalPartitioned => {
let res = ScalarValue::iter_to_array(
- accumulators.group_states.iter().map(|group_state| {
- group_state.accumulator_set[x].evaluate().unwrap()
+ accumulator_set_vec.iter_mut().map(|x| {
+ x.pop_front()
+ .expect("invalid accumulator_set")
+ .evaluate()
+ .unwrap()
}),
)?;
columns.push(res);