This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new f1f8d79ab8 feat: Support spilling for hash aggregation (#7400)
f1f8d79ab8 is described below
commit f1f8d79ab80039ca531e0b20cd4ac4381a073713
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Fri Sep 15 04:11:43 2023 -0700
feat: Support spilling for hash aggregation (#7400)
* Support spilling for hash aggregation
* clippy
* address review comments
* address review comments
* address review comments
* address review comments
* address review comments
* address review comments
* address review comments
* address review comments
* address review comments
* a
* address review comments
* address review comments
* address review comments
* address review comments
---
.../physical_plan/aggregates/group_values/mod.rs | 4 +
.../aggregates/group_values/primitive.rs | 9 +
.../physical_plan/aggregates/group_values/row.rs | 12 +
.../core/src/physical_plan/aggregates/mod.rs | 195 +++++++++++---
.../src/physical_plan/aggregates/order/partial.rs | 2 +-
.../core/src/physical_plan/aggregates/row_hash.rs | 299 +++++++++++++++++++--
datafusion/core/src/physical_plan/sorts/sort.rs | 4 +-
.../physical-expr/src/aggregate/first_last.rs | 16 +-
8 files changed, 482 insertions(+), 59 deletions(-)
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
index f10f83dfe3..cafa385eac 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
@@ -15,6 +15,7 @@
// specific language governing permissions and limitations
// under the License.
+use arrow::record_batch::RecordBatch;
use arrow_array::{downcast_primitive, ArrayRef};
use arrow_schema::SchemaRef;
use datafusion_common::Result;
@@ -42,6 +43,9 @@ pub trait GroupValues: Send {
/// Emits the group values
fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+
+ /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
+ fn clear_shrink(&mut self, batch: &RecordBatch);
}
pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
diff --git
a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
index d7989fb8c4..7a52729d20 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
@@ -20,6 +20,7 @@ use ahash::RandomState;
use arrow::array::BooleanBufferBuilder;
use arrow::buffer::NullBuffer;
use arrow::datatypes::i256;
+use arrow::record_batch::RecordBatch;
use arrow_array::cast::AsArray;
use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType,
PrimitiveArray};
use arrow_schema::DataType;
@@ -206,4 +207,12 @@ where
};
Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
}
+
+ fn clear_shrink(&mut self, batch: &RecordBatch) {
+ let count = batch.num_rows();
+ self.values.clear();
+ self.values.shrink_to(count);
+ self.map.clear();
+ self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+ }
}
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
index 4eb660d525..d711a16191 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
@@ -17,6 +17,7 @@
use crate::physical_plan::aggregates::group_values::GroupValues;
use ahash::RandomState;
+use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, Rows, SortField};
use arrow_array::ArrayRef;
use arrow_schema::SchemaRef;
@@ -181,4 +182,15 @@ impl GroupValues for GroupValuesRows {
}
})
}
+
+ fn clear_shrink(&mut self, batch: &RecordBatch) {
+ let count = batch.num_rows();
+ // FIXME: there is no good way to clear_shrink for self.group_values
+ self.group_values = self.row_converter.empty_rows(count, 0);
+ self.map.clear();
+ self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+ self.map_size = self.map.capacity() * std::mem::size_of::<(u64,
usize)>();
+ self.hashes_buffer.clear();
+ self.hashes_buffer.shrink_to(count);
+ }
}
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs
b/datafusion/core/src/physical_plan/aggregates/mod.rs
index bb3f1edfa8..bbc2b949e2 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -1296,6 +1296,7 @@ mod tests {
use std::sync::Arc;
use std::task::{Context, Poll};
+ use datafusion_execution::config::SessionConfig;
use futures::{FutureExt, Stream};
// Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -1466,7 +1467,22 @@ mod tests {
)
}
- async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+ fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext>
{
+ let session_config = SessionConfig::new().with_batch_size(batch_size);
+ let runtime = Arc::new(
+
RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(max_memory, 1.0))
+ .unwrap(),
+ );
+ let task_ctx = TaskContext::default()
+ .with_session_config(session_config)
+ .with_runtime(runtime);
+ Arc::new(task_ctx)
+ }
+
+ async fn check_grouping_sets(
+ input: Arc<dyn ExecutionPlan>,
+ spill: bool,
+ ) -> Result<()> {
let input_schema = input.schema();
let grouping_set = PhysicalGroupBy {
@@ -1491,7 +1507,11 @@ mod tests {
DataType::Int64,
))];
- let task_ctx = Arc::new(TaskContext::default());
+ let task_ctx = if spill {
+ new_spill_ctx(4, 1000)
+ } else {
+ Arc::new(TaskContext::default())
+ };
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
@@ -1506,24 +1526,53 @@ mod tests {
let result =
common::collect(partial_aggregate.execute(0,
task_ctx.clone())?).await?;
- let expected = vec![
- "+---+-----+-----------------+",
- "| a | b | COUNT(1)[count] |",
- "+---+-----+-----------------+",
- "| | 1.0 | 2 |",
- "| | 2.0 | 2 |",
- "| | 3.0 | 2 |",
- "| | 4.0 | 2 |",
- "| 2 | | 2 |",
- "| 2 | 1.0 | 2 |",
- "| 3 | | 3 |",
- "| 3 | 2.0 | 2 |",
- "| 3 | 3.0 | 1 |",
- "| 4 | | 3 |",
- "| 4 | 3.0 | 1 |",
- "| 4 | 4.0 | 2 |",
- "+---+-----+-----------------+",
- ];
+ let expected = if spill {
+ vec![
+ "+---+-----+-----------------+",
+ "| a | b | COUNT(1)[count] |",
+ "+---+-----+-----------------+",
+ "| | 1.0 | 1 |",
+ "| | 1.0 | 1 |",
+ "| | 2.0 | 1 |",
+ "| | 2.0 | 1 |",
+ "| | 3.0 | 1 |",
+ "| | 3.0 | 1 |",
+ "| | 4.0 | 1 |",
+ "| | 4.0 | 1 |",
+ "| 2 | | 1 |",
+ "| 2 | | 1 |",
+ "| 2 | 1.0 | 1 |",
+ "| 2 | 1.0 | 1 |",
+ "| 3 | | 1 |",
+ "| 3 | | 2 |",
+ "| 3 | 2.0 | 2 |",
+ "| 3 | 3.0 | 1 |",
+ "| 4 | | 1 |",
+ "| 4 | | 2 |",
+ "| 4 | 3.0 | 1 |",
+ "| 4 | 4.0 | 2 |",
+ "+---+-----+-----------------+",
+ ]
+ } else {
+ vec![
+ "+---+-----+-----------------+",
+ "| a | b | COUNT(1)[count] |",
+ "+---+-----+-----------------+",
+ "| | 1.0 | 2 |",
+ "| | 2.0 | 2 |",
+ "| | 3.0 | 2 |",
+ "| | 4.0 | 2 |",
+ "| 2 | | 2 |",
+ "| 2 | 1.0 | 2 |",
+ "| 3 | | 3 |",
+ "| 3 | 2.0 | 2 |",
+ "| 3 | 3.0 | 1 |",
+ "| 4 | | 3 |",
+ "| 4 | 3.0 | 1 |",
+ "| 4 | 4.0 | 2 |",
+ "+---+-----+-----------------+",
+ ]
+ };
assert_batches_sorted_eq!(expected, &result);
let groups = partial_aggregate.group_expr().expr().to_vec();
@@ -1537,6 +1586,12 @@ mod tests {
let final_grouping_set = PhysicalGroupBy::new_single(final_group);
+ let task_ctx = if spill {
+ new_spill_ctx(4, 3160)
+ } else {
+ task_ctx
+ };
+
let merged_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Final,
final_grouping_set,
@@ -1582,7 +1637,7 @@ mod tests {
}
/// build the aggregates on the data from some_data() and check the results
- async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+ async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) ->
Result<()> {
let input_schema = input.schema();
let grouping_set = PhysicalGroupBy {
@@ -1597,7 +1652,11 @@ mod tests {
DataType::Float64,
))];
- let task_ctx = Arc::new(TaskContext::default());
+ let task_ctx = if spill {
+ new_spill_ctx(2, 2144)
+ } else {
+ Arc::new(TaskContext::default())
+ };
let partial_aggregate = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
@@ -1612,15 +1671,29 @@ mod tests {
let result =
common::collect(partial_aggregate.execute(0,
task_ctx.clone())?).await?;
- let expected = [
- "+---+---------------+-------------+",
- "| a | AVG(b)[count] | AVG(b)[sum] |",
- "+---+---------------+-------------+",
- "| 2 | 2 | 2.0 |",
- "| 3 | 3 | 7.0 |",
- "| 4 | 3 | 11.0 |",
- "+---+---------------+-------------+",
- ];
+ let expected = if spill {
+ vec![
+ "+---+---------------+-------------+",
+ "| a | AVG(b)[count] | AVG(b)[sum] |",
+ "+---+---------------+-------------+",
+ "| 2 | 1 | 1.0 |",
+ "| 2 | 1 | 1.0 |",
+ "| 3 | 1 | 2.0 |",
+ "| 3 | 2 | 5.0 |",
+ "| 4 | 3 | 11.0 |",
+ "+---+---------------+-------------+",
+ ]
+ } else {
+ vec![
+ "+---+---------------+-------------+",
+ "| a | AVG(b)[count] | AVG(b)[sum] |",
+ "+---+---------------+-------------+",
+ "| 2 | 2 | 2.0 |",
+ "| 3 | 3 | 7.0 |",
+ "| 4 | 3 | 11.0 |",
+ "+---+---------------+-------------+",
+ ]
+ };
assert_batches_sorted_eq!(expected, &result);
let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
@@ -1663,7 +1736,13 @@ mod tests {
let metrics = merged_aggregate.metrics().unwrap();
let output_rows = metrics.output_rows().unwrap();
- assert_eq!(3, output_rows);
+ if spill {
+ // When spilling, the output rows metric becomes partial output size + final output size.
+ // This is because final aggregation starts while partial aggregation is still emitting.
+ assert_eq!(8, output_rows);
+ } else {
+ assert_eq!(3, output_rows);
+ }
Ok(())
}
@@ -1784,7 +1863,7 @@ mod tests {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });
- check_aggregates(input).await
+ check_aggregates(input, false).await
}
#[tokio::test]
@@ -1792,7 +1871,7 @@ mod tests {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: false });
- check_grouping_sets(input).await
+ check_grouping_sets(input, false).await
}
#[tokio::test]
@@ -1800,7 +1879,7 @@ mod tests {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });
- check_aggregates(input).await
+ check_aggregates(input, false).await
}
#[tokio::test]
@@ -1808,7 +1887,39 @@ mod tests {
let input: Arc<dyn ExecutionPlan> =
Arc::new(TestYieldingExec { yield_first: true });
- check_grouping_sets(input).await
+ check_grouping_sets(input, false).await
+ }
+
+ #[tokio::test]
+ async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
+ let input: Arc<dyn ExecutionPlan> =
+ Arc::new(TestYieldingExec { yield_first: false });
+
+ check_aggregates(input, true).await
+ }
+
+ #[tokio::test]
+ async fn aggregate_grouping_sets_source_not_yielding_with_spill() ->
Result<()> {
+ let input: Arc<dyn ExecutionPlan> =
+ Arc::new(TestYieldingExec { yield_first: false });
+
+ check_grouping_sets(input, true).await
+ }
+
+ #[tokio::test]
+ async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
+ let input: Arc<dyn ExecutionPlan> =
+ Arc::new(TestYieldingExec { yield_first: true });
+
+ check_aggregates(input, true).await
+ }
+
+ #[tokio::test]
+ async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
+ let input: Arc<dyn ExecutionPlan> =
+ Arc::new(TestYieldingExec { yield_first: true });
+
+ check_grouping_sets(input, true).await
}
#[tokio::test]
@@ -1976,7 +2087,10 @@ mod tests {
async fn run_first_last_multi_partitions() -> Result<()> {
for use_coalesce_batches in [false, true] {
for is_first_acc in [false, true] {
- first_last_multi_partitions(use_coalesce_batches,
is_first_acc).await?
+ for spill in [false, true] {
+ first_last_multi_partitions(use_coalesce_batches,
is_first_acc, spill)
+ .await?
+ }
}
}
Ok(())
@@ -2002,8 +2116,13 @@ mod tests {
async fn first_last_multi_partitions(
use_coalesce_batches: bool,
is_first_acc: bool,
+ spill: bool,
) -> Result<()> {
- let task_ctx = Arc::new(TaskContext::default());
+ let task_ctx = if spill {
+ new_spill_ctx(2, 2812)
+ } else {
+ Arc::new(TaskContext::default())
+ };
let (schema, data) = some_data_v2();
let partition1 = data[0].clone();
diff --git a/datafusion/core/src/physical_plan/aggregates/order/partial.rs
b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
index 019e61ef26..0feac3a5ed 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/partial.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
@@ -241,7 +241,7 @@ impl GroupOrderingPartial {
Ok(())
}
- /// Return the size of memor allocated by this structure
+ /// Return the size of memory allocated by this structure
pub(crate) fn size(&self) -> usize {
std::mem::size_of::<Self>()
+ self.order_indices.allocated_size()
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index d034bd669e..eef25c1dc2 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -18,7 +18,7 @@
//! Hash aggregation
use datafusion_physical_expr::{
- AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter,
+ AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter,
PhysicalSortExpr,
};
use log::debug;
use std::sync::Arc;
@@ -29,19 +29,28 @@ use futures::ready;
use futures::stream::{Stream, StreamExt};
use crate::physical_plan::aggregates::group_values::{new_group_values,
GroupValues};
+use crate::physical_plan::aggregates::order::GroupOrderingFull;
use crate::physical_plan::aggregates::{
evaluate_group_by, evaluate_many, evaluate_optional, group_schema,
AggregateMode,
PhysicalGroupBy,
};
+use crate::physical_plan::common::IPCWriter;
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::sorts::sort::{read_spill_as_stream, sort_batch};
+use crate::physical_plan::sorts::streaming_merge;
+use crate::physical_plan::stream::RecordBatchStreamAdapter;
use crate::physical_plan::{aggregates, PhysicalExpr};
use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
use arrow::array::*;
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
-use datafusion_common::Result;
+use arrow_schema::SortOptions;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
+use datafusion_physical_expr::expressions::col;
#[derive(Debug, Clone)]
/// This object tracks the aggregation phase (input/output)
@@ -56,6 +65,28 @@ pub(crate) enum ExecutionState {
use super::order::GroupOrdering;
use super::AggregateExec;
+/// This encapsulates the spilling state
+struct SpillState {
+ /// If data has previously been spilled, the locations of the
+ /// spill files (in Arrow IPC format)
+ spills: Vec<RefCountedTempFile>,
+
+ /// Sorting expression for spilling batches
+ spill_expr: Vec<PhysicalSortExpr>,
+
+ /// Schema for spilling batches
+ spill_schema: SchemaRef,
+
+ /// true when streaming merge is in progress
+ is_stream_merging: bool,
+
+ /// aggregate_arguments for merging spilled data
+ merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+
+ /// GROUP BY expressions for merging spilled data
+ merging_group_by: PhysicalGroupBy,
+}
+
/// HashTable based Grouping Aggregator
///
/// # Design Goals
@@ -120,6 +151,57 @@ use super::AggregateExec;
/// hash table).
///
/// [`group_values`]: Self::group_values
+///
+/// # Spilling
+///
+/// The sizes of group values and accumulators can become large. Before that causes an
+/// out-of-memory error, this hash aggregator either emits partial states early (in
+/// [`AggregateMode::Partial`]) or spills sorted data to local disk using the Arrow IPC
+/// format (in the other aggregation modes). For every input [`RecordBatch`], the memory
+/// manager checks whether the new input size fits within the memory limit. If not, early
+/// emitting or spilling happens. For early emitting, the downstream final aggregation takes
+/// care of re-grouping. For spilling, a streaming merge sort performed when reading the
+/// spilled data back does the re-grouping. Note that rows cannot be grouped once spilled onto
+/// disk, so the data read back must be re-grouped. In addition, re-grouping may run out of
+/// memory again, so it has to be a sort-based aggregation.
+///
+/// ```text
+/// Partial Aggregation [batch_size = 2] (max memory = 3 rows)
+///
+/// INPUTS PARTIALLY AGGREGATED (UPDATE BATCH) OUTPUTS
+/// ┌─────────┐ ┌─────────────────┐ ┌─────────────────┐
+/// │ a │ b │ │ a │ AVG(b) │ │ a │ AVG(b) │
+/// │---│-----│ │ │[count]│[sum]│ │ │[count]│[sum]│
+/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│ │---│-------│-----│
+/// │ 2 │ 2.0 │ │ 2 │ 1 │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1 │ 2.0 │
+/// └─────────┘ │ 3 │ 2 │ 7.0 │ │ │ 3 │ 2 │ 7.0 │
+/// ┌─────────┐ ─▶ │ 4 │ 1 │ 8.0 │ │ └─────────────────┘
+/// │ 3 │ 4.0 │ └─────────────────┘ └▶ ┌─────────────────┐
+/// │ 4 │ 8.0 │ ┌─────────────────┐ │ 4 │ 1 │ 8.0 │
+/// └─────────┘ │ a │ AVG(b) │ ┌▶ │ 1 │ 1 │ 1.0 │
+/// ┌─────────┐ │---│-------│-----│ │ └─────────────────┘
+/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1 │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐
+/// │ 3 │ 2.0 │ │ 3 │ 1 │ 2.0 │ │ 3 │ 1 │ 2.0 │
+/// └─────────┘ └─────────────────┘ └─────────────────┘
+///
+///
+/// Final Aggregation [batch_size = 2] (max memory = 3 rows)
+///
+/// PARTIALLY INPUTS FINAL AGGREGATION (MERGE BATCH) RE-GROUPED (SORTED)
+/// ┌─────────────────┐ [keep using the partial schema] [Real final aggregation
+/// │ a │ AVG(b) │ ┌─────────────────┐ output]
+/// │ │[count]│[sum]│ │ a │ AVG(b) │ ┌────────────┐
+/// │---│-------│-----│ ─▶ │ │[count]│[sum]│ │ a │ AVG(b) │
+/// │ 3 │ 3 │ 3.0 │ │---│-------│-----│ ─▶ spill ─┐ │---│--------│
+/// │ 2 │ 2 │ 1.0 │ │ 2 │ 2 │ 1.0 │ │ │ 1 │ 4.0 │
+/// └─────────────────┘ │ 3 │ 4 │ 8.0 │ ▼ │ 2 │ 1.0 │
+/// ┌─────────────────┐ ─▶ │ 4 │ 1 │ 7.0 │ Streaming ─▶ └────────────┘
+/// │ 3 │ 1 │ 5.0 │ └─────────────────┘ merge sort ─▶ ┌────────────┐
+/// │ 4 │ 1 │ 7.0 │ ┌─────────────────┐ ▲ │ a │ AVG(b) │
+/// └─────────────────┘ │ a │ AVG(b) │ │ │---│--------│
+/// ┌─────────────────┐ │---│-------│-----│ ─▶ memory ─┘ │ 3 │ 2.0 │
+/// │ 1 │ 2 │ 8.0 │ ─▶ │ 1 │ 2 │ 8.0 │ │ 4 │ 7.0 │
+/// │ 2 │ 2 │ 3.0 │ │ 2 │ 2 │ 3.0 │ └────────────┘
+/// └─────────────────┘ └─────────────────┘
+/// ```
pub(crate) struct GroupedHashAggregateStream {
schema: SchemaRef,
input: SendableRecordBatchStream,
@@ -178,6 +260,12 @@ pub(crate) struct GroupedHashAggregateStream {
/// Have we seen the end of the input
input_done: bool,
+
+ /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
+ runtime: Arc<RuntimeEnv>,
+
+ /// The spill state object
+ spill_state: SpillState,
}
impl GroupedHashAggregateStream {
@@ -207,6 +295,12 @@ impl GroupedHashAggregateStream {
&agg.mode,
agg_group_by.expr.len(),
)?;
+ // arguments for aggregating spilled data are the same as the ones for final aggregation
+ let merging_aggregate_arguments = aggregates::aggregate_expressions(
+ &agg.aggr_expr,
+ &AggregateMode::Final,
+ agg_group_by.expr.len(),
+ )?;
let filter_expressions = match agg.mode {
AggregateMode::Partial
@@ -224,6 +318,14 @@ impl GroupedHashAggregateStream {
.collect::<Result<_>>()?;
let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+ let spill_expr = group_schema
+ .fields
+ .into_iter()
+ .map(|field| PhysicalSortExpr {
+ expr: col(field.name(), &group_schema).unwrap(),
+ options: SortOptions::default(),
+ })
+ .collect();
let name = format!("GroupedHashAggregateStream[{partition}]");
let reservation =
MemoryConsumer::new(name).register(context.memory_pool());
@@ -243,6 +345,15 @@ impl GroupedHashAggregateStream {
let exec_state = ExecutionState::ReadingInput;
+ let spill_state = SpillState {
+ spills: vec![],
+ spill_expr,
+ spill_schema: agg_schema.clone(),
+ is_stream_merging: false,
+ merging_aggregate_arguments,
+ merging_group_by:
PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
+ };
+
Ok(GroupedHashAggregateStream {
schema: agg_schema,
input,
@@ -259,6 +370,8 @@ impl GroupedHashAggregateStream {
batch_size,
group_ordering,
input_done: false,
+ runtime: context.runtime_env(),
+ spill_state,
})
}
}
@@ -310,6 +423,9 @@ impl Stream for GroupedHashAggregateStream {
// new batch to aggregate
Some(Ok(batch)) => {
let timer = elapsed_compute.timer();
+ // Make sure we have enough capacity for `batch`, otherwise spill
+
extract_ok!(self.spill_previous_if_necessary(&batch));
+
// Do the grouping
extract_ok!(self.group_aggregate_batch(batch));
@@ -318,9 +434,12 @@ impl Stream for GroupedHashAggregateStream {
assert!(!self.input_done);
if let Some(to_emit) =
self.group_ordering.emit_to() {
- let batch = extract_ok!(self.emit(to_emit));
+ let batch = extract_ok!(self.emit(to_emit,
false));
self.exec_state =
ExecutionState::ProducingOutput(batch);
}
+
+ extract_ok!(self.emit_early_if_necessary());
+
timer.done();
}
Some(Err(e)) => {
@@ -332,8 +451,14 @@ impl Stream for GroupedHashAggregateStream {
self.input_done = true;
self.group_ordering.input_done();
let timer = elapsed_compute.timer();
- let batch = extract_ok!(self.emit(EmitTo::All));
- self.exec_state =
ExecutionState::ProducingOutput(batch);
+ if self.spill_state.spills.is_empty() {
+ let batch = extract_ok!(self.emit(EmitTo::All,
false));
+ self.exec_state =
ExecutionState::ProducingOutput(batch);
+ } else {
+ // If spill files exist, stream-merge them.
+ extract_ok!(self.update_merged_stream());
+ self.exec_state = ExecutionState::ReadingInput;
+ }
timer.done();
}
}
@@ -360,7 +485,13 @@ impl Stream for GroupedHashAggregateStream {
)));
}
- ExecutionState::Done => return Poll::Ready(None),
+ ExecutionState::Done => {
+ // release the memory reservation since sending back the output batch itself needs
+ // some memory reservation, so make some room for it.
+ self.clear_all();
+ let _ = self.update_memory_reservation();
+ return Poll::Ready(None);
+ }
}
}
}
@@ -376,13 +507,26 @@ impl GroupedHashAggregateStream {
/// Perform group-by aggregation for the given [`RecordBatch`].
fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
// Evaluate the grouping expressions
- let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+ let group_by_values = if self.spill_state.is_stream_merging {
+ evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
+ } else {
+ evaluate_group_by(&self.group_by, &batch)?
+ };
// Evaluate the aggregation expressions.
- let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+ let input_values = if self.spill_state.is_stream_merging {
+ evaluate_many(&self.spill_state.merging_aggregate_arguments,
&batch)?
+ } else {
+ evaluate_many(&self.aggregate_arguments, &batch)?
+ };
// Evaluate the filter expressions, if any, against the inputs
- let filter_values = evaluate_optional(&self.filter_expressions,
&batch)?;
+ let filter_values = if self.spill_state.is_stream_merging {
+ let filter_expressions = vec![None; self.accumulators.len()];
+ evaluate_optional(&filter_expressions, &batch)?
+ } else {
+ evaluate_optional(&self.filter_expressions, &batch)?
+ };
for group_values in &group_by_values {
// calculate the group indices for each input row
@@ -416,7 +560,9 @@ impl GroupedHashAggregateStream {
match self.mode {
AggregateMode::Partial
| AggregateMode::Single
- | AggregateMode::SinglePartitioned => {
+ | AggregateMode::SinglePartitioned
+ if !self.spill_state.is_stream_merging =>
+ {
acc.update_batch(
values,
group_indices,
@@ -424,7 +570,7 @@ impl GroupedHashAggregateStream {
total_num_groups,
)?;
}
- AggregateMode::FinalPartitioned | AggregateMode::Final => {
+ _ => {
// if aggregation is over intermediate states,
// use merge
acc.merge_batch(
@@ -438,7 +584,16 @@ impl GroupedHashAggregateStream {
}
}
- self.update_memory_reservation()
+ match self.update_memory_reservation() {
+ // Here we can ignore `insufficient_capacity_err` because we will spill (or emit early) later,
+ // but at least one batch should fit in the memory
+ Err(DataFusionError::ResourcesExhausted(_))
+ if self.group_values.len() >= self.batch_size =>
+ {
+ Ok(())
+ }
+ other => other,
+ }
}
fn update_memory_reservation(&mut self) -> Result<()> {
@@ -452,9 +607,14 @@ impl GroupedHashAggregateStream {
/// Create an output RecordBatch with the group keys and
/// accumulator states/values specified in emit_to
- fn emit(&mut self, emit_to: EmitTo) -> Result<RecordBatch> {
+ fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch>
{
+ let schema = if spilling {
+ self.spill_state.spill_schema.clone()
+ } else {
+ self.schema()
+ };
if self.group_values.is_empty() {
- return Ok(RecordBatch::new_empty(self.schema()));
+ return Ok(RecordBatch::new_empty(schema));
}
let mut output = self.group_values.emit(emit_to)?;
@@ -466,6 +626,11 @@ impl GroupedHashAggregateStream {
for acc in self.accumulators.iter_mut() {
match self.mode {
AggregateMode::Partial => output.extend(acc.state(emit_to)?),
+ _ if spilling => {
+ // If spilling, output partial state because the spilled
data will be
+ // merged and re-evaluated later.
+ output.extend(acc.state(emit_to)?)
+ }
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single
@@ -473,8 +638,110 @@ impl GroupedHashAggregateStream {
}
}
- self.update_memory_reservation()?;
- let batch = RecordBatch::try_new(self.schema(), output)?;
+ // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
+ // over the target memory size after emission, we can emit again rather than returning Err.
+ let _ = self.update_memory_reservation();
+ let batch = RecordBatch::try_new(schema, output)?;
Ok(batch)
}
+
+ /// Optimistically, [`Self::group_aggregate_batch`] allows exceeding the memory target slightly
+ /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
+ /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
+ fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) ->
Result<()> {
+ // TODO: support group_ordering for spilling
+ if self.group_values.len() > 0
+ && batch.num_rows() > 0
+ && matches!(self.group_ordering, GroupOrdering::None)
+ && !matches!(self.mode, AggregateMode::Partial)
+ && !self.spill_state.is_stream_merging
+ && self.update_memory_reservation().is_err()
+ {
+ // Use input batch (Partial mode) schema for spilling because
+ // the spilled data will be merged and re-evaluated later.
+ self.spill_state.spill_schema = batch.schema();
+ self.spill()?;
+ self.clear_shrink(batch);
+ }
+ Ok(())
+ }
+
+ /// Emit all rows, sort them, and store them on disk.
+ fn spill(&mut self) -> Result<()> {
+ let emit = self.emit(EmitTo::All, true)?;
+ let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
+ let spillfile =
self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
+ let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?;
+ // TODO: slice large `sorted` and write to multiple files in parallel
+ writer.write(&sorted)?;
+ writer.finish()?;
+ self.spill_state.spills.push(spillfile);
+ Ok(())
+ }
+
+ /// Clear memory and shrink capacities to the size of the batch.
+ fn clear_shrink(&mut self, batch: &RecordBatch) {
+ self.group_values.clear_shrink(batch);
+ self.current_group_indices.clear();
+ self.current_group_indices.shrink_to(batch.num_rows());
+ }
+
+ /// Clear memory and shrink capacities to zero.
+ fn clear_all(&mut self) {
+ let s = self.schema();
+ self.clear_shrink(&RecordBatch::new_empty(s));
+ }
+
+ /// Emit if the used memory exceeds the target for partial aggregation.
+ /// Currently only [`GroupOrdering::None`] is supported for early emitting.
+ /// TODO: support group_ordering for early emitting
+ fn emit_early_if_necessary(&mut self) -> Result<()> {
+ if self.group_values.len() >= self.batch_size
+ && matches!(self.group_ordering, GroupOrdering::None)
+ && matches!(self.mode, AggregateMode::Partial)
+ && self.update_memory_reservation().is_err()
+ {
+ let n = self.group_values.len() / self.batch_size *
self.batch_size;
+ let batch = self.emit(EmitTo::First(n), false)?;
+ self.exec_state = ExecutionState::ProducingOutput(batch);
+ }
+ Ok(())
+ }
+
+ /// At this point, all the inputs are read and there are some spills.
+ /// Emit the remaining rows and create a batch.
+ /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
+ /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
+ fn update_merged_stream(&mut self) -> Result<()> {
+ let batch = self.emit(EmitTo::All, true)?;
+ // clear up memory for streaming_merge
+ self.clear_all();
+ self.update_memory_reservation()?;
+ let mut streams: Vec<SendableRecordBatchStream> = vec![];
+ let expr = self.spill_state.spill_expr.clone();
+ let schema = batch.schema();
+ streams.push(Box::pin(RecordBatchStreamAdapter::new(
+ schema.clone(),
+ futures::stream::once(futures::future::lazy(move |_| {
+ sort_batch(&batch, &expr, None)
+ })),
+ )));
+ for spill in self.spill_state.spills.drain(..) {
+ let stream = read_spill_as_stream(spill, schema.clone())?;
+ streams.push(stream);
+ }
+ self.spill_state.is_stream_merging = true;
+ self.input = streaming_merge(
+ streams,
+ schema,
+ &self.spill_state.spill_expr,
+ self.baseline_metrics.clone(),
+ self.batch_size,
+ None,
+ self.reservation.new_empty(),
+ )?;
+ self.input_done = false;
+ self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
+ Ok(())
+ }
}
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs
b/datafusion/core/src/physical_plan/sorts/sort.rs
index 17b94d51c5..92fb45142e 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -574,7 +574,7 @@ impl Debug for ExternalSorter {
}
}
-fn sort_batch(
+pub(crate) fn sort_batch(
batch: &RecordBatch,
expressions: &[PhysicalSortExpr],
fetch: Option<usize>,
@@ -608,7 +608,7 @@ async fn spill_sorted_batches(
}
}
-fn read_spill_as_stream(
+pub(crate) fn read_spill_as_stream(
path: RefCountedTempFile,
schema: SchemaRef,
) -> Result<SendableRecordBatchStream> {
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs
b/datafusion/physical-expr/src/aggregate/first_last.rs
index 7e8930ce2a..02bb466d44 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -165,6 +165,8 @@ struct FirstValueAccumulator {
orderings: Vec<ScalarValue>,
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
+ // Whether merge_batch() is called before
+ is_merge_called: bool,
}
impl FirstValueAccumulator {
@@ -183,6 +185,7 @@ impl FirstValueAccumulator {
is_set: false,
orderings,
ordering_req,
+ is_merge_called: false,
})
}
@@ -198,7 +201,9 @@ impl Accumulator for FirstValueAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.first.clone()];
result.extend(self.orderings.iter().cloned());
- result.push(ScalarValue::Boolean(Some(self.is_set)));
+ if !self.is_merge_called {
+ result.push(ScalarValue::Boolean(Some(self.is_set)));
+ }
Ok(result)
}
@@ -213,6 +218,7 @@ impl Accumulator for FirstValueAccumulator {
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ self.is_merge_called = true;
// FIRST_VALUE(first1, first2, first3, ...)
// last index contains is_set flag.
let is_set_idx = states.len() - 1;
@@ -384,6 +390,8 @@ struct LastValueAccumulator {
orderings: Vec<ScalarValue>,
// Stores the applicable ordering requirement.
ordering_req: LexOrdering,
+ // Whether merge_batch() is called before
+ is_merge_called: bool,
}
impl LastValueAccumulator {
@@ -402,6 +410,7 @@ impl LastValueAccumulator {
is_set: false,
orderings,
ordering_req,
+ is_merge_called: false,
})
}
@@ -417,7 +426,9 @@ impl Accumulator for LastValueAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.last.clone()];
result.extend(self.orderings.clone());
- result.push(ScalarValue::Boolean(Some(self.is_set)));
+ if !self.is_merge_called {
+ result.push(ScalarValue::Boolean(Some(self.is_set)));
+ }
Ok(result)
}
@@ -431,6 +442,7 @@ impl Accumulator for LastValueAccumulator {
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+ self.is_merge_called = true;
// LAST_VALUE(last1, last2, last3, ...)
// last index contains is_set flag.
let is_set_idx = states.len() - 1;