This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new f1f8d79ab8 feat: Support spilling for hash aggregation (#7400)
f1f8d79ab8 is described below

commit f1f8d79ab80039ca531e0b20cd4ac4381a073713
Author: KAZUYUKI TANIMURA <[email protected]>
AuthorDate: Fri Sep 15 04:11:43 2023 -0700

    feat: Support spilling for hash aggregation (#7400)
    
    * Support spilling for hash aggregation
    
    * clippy
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * a
    
    * address review comments
    
    * address review comments
    
    * address review comments
    
    * address review comments
---
 .../physical_plan/aggregates/group_values/mod.rs   |   4 +
 .../aggregates/group_values/primitive.rs           |   9 +
 .../physical_plan/aggregates/group_values/row.rs   |  12 +
 .../core/src/physical_plan/aggregates/mod.rs       | 195 +++++++++++---
 .../src/physical_plan/aggregates/order/partial.rs  |   2 +-
 .../core/src/physical_plan/aggregates/row_hash.rs  | 299 +++++++++++++++++++--
 datafusion/core/src/physical_plan/sorts/sort.rs    |   4 +-
 .../physical-expr/src/aggregate/first_last.rs      |  16 +-
 8 files changed, 482 insertions(+), 59 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
index f10f83dfe3..cafa385eac 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use arrow::record_batch::RecordBatch;
 use arrow_array::{downcast_primitive, ArrayRef};
 use arrow_schema::SchemaRef;
 use datafusion_common::Result;
@@ -42,6 +43,9 @@ pub trait GroupValues: Send {
 
     /// Emits the group values
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+
+    /// Clear the contents and shrink the capacity to the size of the batch (free up memory usage)
+    fn clear_shrink(&mut self, batch: &RecordBatch);
 }
 
 pub fn new_group_values(schema: SchemaRef) -> Result<Box<dyn GroupValues>> {
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
index d7989fb8c4..7a52729d20 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/primitive.rs
@@ -20,6 +20,7 @@ use ahash::RandomState;
 use arrow::array::BooleanBufferBuilder;
 use arrow::buffer::NullBuffer;
 use arrow::datatypes::i256;
+use arrow::record_batch::RecordBatch;
 use arrow_array::cast::AsArray;
 use arrow_array::{ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, PrimitiveArray};
 use arrow_schema::DataType;
@@ -206,4 +207,12 @@ where
         };
         Ok(vec![Arc::new(array.with_data_type(self.data_type.clone()))])
     }
+
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        let count = batch.num_rows();
+        self.values.clear();
+        self.values.shrink_to(count);
+        self.map.clear();
+        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+    }
 }
diff --git a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
index 4eb660d525..d711a16191 100644
--- a/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
+++ b/datafusion/core/src/physical_plan/aggregates/group_values/row.rs
@@ -17,6 +17,7 @@
 
 use crate::physical_plan::aggregates::group_values::GroupValues;
 use ahash::RandomState;
+use arrow::record_batch::RecordBatch;
 use arrow::row::{RowConverter, Rows, SortField};
 use arrow_array::ArrayRef;
 use arrow_schema::SchemaRef;
@@ -181,4 +182,15 @@ impl GroupValues for GroupValuesRows {
             }
         })
     }
+
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        let count = batch.num_rows();
+        // FIXME: there is no good way to clear_shrink for self.group_values
+        self.group_values = self.row_converter.empty_rows(count, 0);
+        self.map.clear();
+        self.map.shrink_to(count, |_| 0); // hasher does not matter since the map is cleared
+        self.map_size = self.map.capacity() * std::mem::size_of::<(u64, usize)>();
+        self.hashes_buffer.clear();
+        self.hashes_buffer.shrink_to(count);
+    }
 }
diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs
index bb3f1edfa8..bbc2b949e2 100644
--- a/datafusion/core/src/physical_plan/aggregates/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/mod.rs
@@ -1296,6 +1296,7 @@ mod tests {
     use std::sync::Arc;
     use std::task::{Context, Poll};
 
+    use datafusion_execution::config::SessionConfig;
     use futures::{FutureExt, Stream};
 
     // Generate a schema which consists of 5 columns (a, b, c, d, e)
@@ -1466,7 +1467,22 @@ mod tests {
         )
     }
 
-    async fn check_grouping_sets(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+    fn new_spill_ctx(batch_size: usize, max_memory: usize) -> Arc<TaskContext> {
+        let session_config = SessionConfig::new().with_batch_size(batch_size);
+        let runtime = Arc::new(
+            RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(max_memory, 1.0))
+                .unwrap(),
+        );
+        let task_ctx = TaskContext::default()
+            .with_session_config(session_config)
+            .with_runtime(runtime);
+        Arc::new(task_ctx)
+    }
+
+    async fn check_grouping_sets(
+        input: Arc<dyn ExecutionPlan>,
+        spill: bool,
+    ) -> Result<()> {
         let input_schema = input.schema();
 
         let grouping_set = PhysicalGroupBy {
@@ -1491,7 +1507,11 @@ mod tests {
             DataType::Int64,
         ))];
 
-        let task_ctx = Arc::new(TaskContext::default());
+        let task_ctx = if spill {
+            new_spill_ctx(4, 1000)
+        } else {
+            Arc::new(TaskContext::default())
+        };
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Partial,
@@ -1506,24 +1526,53 @@ mod tests {
         let result =
             common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
 
-        let expected = vec![
-            "+---+-----+-----------------+",
-            "| a | b   | COUNT(1)[count] |",
-            "+---+-----+-----------------+",
-            "|   | 1.0 | 2               |",
-            "|   | 2.0 | 2               |",
-            "|   | 3.0 | 2               |",
-            "|   | 4.0 | 2               |",
-            "| 2 |     | 2               |",
-            "| 2 | 1.0 | 2               |",
-            "| 3 |     | 3               |",
-            "| 3 | 2.0 | 2               |",
-            "| 3 | 3.0 | 1               |",
-            "| 4 |     | 3               |",
-            "| 4 | 3.0 | 1               |",
-            "| 4 | 4.0 | 2               |",
-            "+---+-----+-----------------+",
-        ];
+        let expected = if spill {
+            vec![
+                "+---+-----+-----------------+",
+                "| a | b   | COUNT(1)[count] |",
+                "+---+-----+-----------------+",
+                "|   | 1.0 | 1               |",
+                "|   | 1.0 | 1               |",
+                "|   | 2.0 | 1               |",
+                "|   | 2.0 | 1               |",
+                "|   | 3.0 | 1               |",
+                "|   | 3.0 | 1               |",
+                "|   | 4.0 | 1               |",
+                "|   | 4.0 | 1               |",
+                "| 2 |     | 1               |",
+                "| 2 |     | 1               |",
+                "| 2 | 1.0 | 1               |",
+                "| 2 | 1.0 | 1               |",
+                "| 3 |     | 1               |",
+                "| 3 |     | 2               |",
+                "| 3 | 2.0 | 2               |",
+                "| 3 | 3.0 | 1               |",
+                "| 4 |     | 1               |",
+                "| 4 |     | 2               |",
+                "| 4 | 3.0 | 1               |",
+                "| 4 | 4.0 | 2               |",
+                "+---+-----+-----------------+",
+            ]
+        } else {
+            vec![
+                "+---+-----+-----------------+",
+                "| a | b   | COUNT(1)[count] |",
+                "+---+-----+-----------------+",
+                "|   | 1.0 | 2               |",
+                "|   | 2.0 | 2               |",
+                "|   | 3.0 | 2               |",
+                "|   | 4.0 | 2               |",
+                "| 2 |     | 2               |",
+                "| 2 | 1.0 | 2               |",
+                "| 3 |     | 3               |",
+                "| 3 | 2.0 | 2               |",
+                "| 3 | 3.0 | 1               |",
+                "| 4 |     | 3               |",
+                "| 4 | 3.0 | 1               |",
+                "| 4 | 4.0 | 2               |",
+                "+---+-----+-----------------+",
+            ]
+        };
         assert_batches_sorted_eq!(expected, &result);
 
         let groups = partial_aggregate.group_expr().expr().to_vec();
@@ -1537,6 +1586,12 @@ mod tests {
 
         let final_grouping_set = PhysicalGroupBy::new_single(final_group);
 
+        let task_ctx = if spill {
+            new_spill_ctx(4, 3160)
+        } else {
+            task_ctx
+        };
+
         let merged_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Final,
             final_grouping_set,
@@ -1582,7 +1637,7 @@ mod tests {
     }
 
     /// build the aggregates on the data from some_data() and check the results
-    async fn check_aggregates(input: Arc<dyn ExecutionPlan>) -> Result<()> {
+    async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
         let input_schema = input.schema();
 
         let grouping_set = PhysicalGroupBy {
@@ -1597,7 +1652,11 @@ mod tests {
             DataType::Float64,
         ))];
 
-        let task_ctx = Arc::new(TaskContext::default());
+        let task_ctx = if spill {
+            new_spill_ctx(2, 2144)
+        } else {
+            Arc::new(TaskContext::default())
+        };
 
         let partial_aggregate = Arc::new(AggregateExec::try_new(
             AggregateMode::Partial,
@@ -1612,15 +1671,29 @@ mod tests {
         let result =
             common::collect(partial_aggregate.execute(0, task_ctx.clone())?).await?;
 
-        let expected = [
-            "+---+---------------+-------------+",
-            "| a | AVG(b)[count] | AVG(b)[sum] |",
-            "+---+---------------+-------------+",
-            "| 2 | 2             | 2.0         |",
-            "| 3 | 3             | 7.0         |",
-            "| 4 | 3             | 11.0        |",
-            "+---+---------------+-------------+",
-        ];
+        let expected = if spill {
+            vec![
+                "+---+---------------+-------------+",
+                "| a | AVG(b)[count] | AVG(b)[sum] |",
+                "+---+---------------+-------------+",
+                "| 2 | 1             | 1.0         |",
+                "| 2 | 1             | 1.0         |",
+                "| 3 | 1             | 2.0         |",
+                "| 3 | 2             | 5.0         |",
+                "| 4 | 3             | 11.0        |",
+                "+---+---------------+-------------+",
+            ]
+        } else {
+            vec![
+                "+---+---------------+-------------+",
+                "| a | AVG(b)[count] | AVG(b)[sum] |",
+                "+---+---------------+-------------+",
+                "| 2 | 2             | 2.0         |",
+                "| 3 | 3             | 7.0         |",
+                "| 4 | 3             | 11.0        |",
+                "+---+---------------+-------------+",
+            ]
+        };
         assert_batches_sorted_eq!(expected, &result);
 
         let merge = Arc::new(CoalescePartitionsExec::new(partial_aggregate));
@@ -1663,7 +1736,13 @@ mod tests {
 
         let metrics = merged_aggregate.metrics().unwrap();
         let output_rows = metrics.output_rows().unwrap();
-        assert_eq!(3, output_rows);
+        if spill {
+            // When spilling, the output rows metrics become partial output size + final output size
+            // This is because final aggregation starts while partial aggregation is still emitting
+            assert_eq!(8, output_rows);
+        } else {
+            assert_eq!(3, output_rows);
+        }
 
         Ok(())
     }
@@ -1784,7 +1863,7 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: false });
 
-        check_aggregates(input).await
+        check_aggregates(input, false).await
     }
 
     #[tokio::test]
@@ -1792,7 +1871,7 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: false });
 
-        check_grouping_sets(input).await
+        check_grouping_sets(input, false).await
     }
 
     #[tokio::test]
@@ -1800,7 +1879,7 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: true });
 
-        check_aggregates(input).await
+        check_aggregates(input, false).await
     }
 
     #[tokio::test]
@@ -1808,7 +1887,39 @@ mod tests {
         let input: Arc<dyn ExecutionPlan> =
             Arc::new(TestYieldingExec { yield_first: true });
 
-        check_grouping_sets(input).await
+        check_grouping_sets(input, false).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_source_not_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: false });
+
+        check_aggregates(input, true).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_grouping_sets_source_not_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: false });
+
+        check_grouping_sets(input, true).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_source_with_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: true });
+
+        check_aggregates(input, true).await
+    }
+
+    #[tokio::test]
+    async fn aggregate_grouping_sets_with_yielding_with_spill() -> Result<()> {
+        let input: Arc<dyn ExecutionPlan> =
+            Arc::new(TestYieldingExec { yield_first: true });
+
+        check_grouping_sets(input, true).await
     }
 
     #[tokio::test]
@@ -1976,7 +2087,10 @@ mod tests {
     async fn run_first_last_multi_partitions() -> Result<()> {
         for use_coalesce_batches in [false, true] {
             for is_first_acc in [false, true] {
-                first_last_multi_partitions(use_coalesce_batches, is_first_acc).await?
+                for spill in [false, true] {
+                    first_last_multi_partitions(use_coalesce_batches, is_first_acc, spill)
+                        .await?
+                }
             }
         }
         Ok(())
@@ -2002,8 +2116,13 @@ mod tests {
     async fn first_last_multi_partitions(
         use_coalesce_batches: bool,
         is_first_acc: bool,
+        spill: bool,
     ) -> Result<()> {
-        let task_ctx = Arc::new(TaskContext::default());
+        let task_ctx = if spill {
+            new_spill_ctx(2, 2812)
+        } else {
+            Arc::new(TaskContext::default())
+        };
 
         let (schema, data) = some_data_v2();
         let partition1 = data[0].clone();
diff --git a/datafusion/core/src/physical_plan/aggregates/order/partial.rs b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
index 019e61ef26..0feac3a5ed 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/partial.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/partial.rs
@@ -241,7 +241,7 @@ impl GroupOrderingPartial {
         Ok(())
     }
 
-    /// Return the size of memor allocated by this structure
+    /// Return the size of memory allocated by this structure
     pub(crate) fn size(&self) -> usize {
         std::mem::size_of::<Self>()
             + self.order_indices.allocated_size()
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index d034bd669e..eef25c1dc2 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -18,7 +18,7 @@
 //! Hash aggregation
 
 use datafusion_physical_expr::{
-    AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter,
+    AggregateExpr, EmitTo, GroupsAccumulator, GroupsAccumulatorAdapter, PhysicalSortExpr,
 };
 use log::debug;
 use std::sync::Arc;
@@ -29,19 +29,28 @@ use futures::ready;
 use futures::stream::{Stream, StreamExt};
 
 use crate::physical_plan::aggregates::group_values::{new_group_values, GroupValues};
+use crate::physical_plan::aggregates::order::GroupOrderingFull;
 use crate::physical_plan::aggregates::{
     evaluate_group_by, evaluate_many, evaluate_optional, group_schema, AggregateMode,
     PhysicalGroupBy,
 };
+use crate::physical_plan::common::IPCWriter;
 use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
+use crate::physical_plan::sorts::sort::{read_spill_as_stream, sort_batch};
+use crate::physical_plan::sorts::streaming_merge;
+use crate::physical_plan::stream::RecordBatchStreamAdapter;
 use crate::physical_plan::{aggregates, PhysicalExpr};
 use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
 use arrow::array::*;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
-use datafusion_common::Result;
+use arrow_schema::SortOptions;
+use datafusion_common::{DataFusionError, Result};
+use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
+use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
+use datafusion_physical_expr::expressions::col;
 
 #[derive(Debug, Clone)]
 /// This object tracks the aggregation phase (input/output)
@@ -56,6 +65,28 @@ pub(crate) enum ExecutionState {
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
+/// This encapsulates the spilling state
+struct SpillState {
+    /// If data has previously been spilled, the locations of the
+    /// spill files (in Arrow IPC format)
+    spills: Vec<RefCountedTempFile>,
+
+    /// Sorting expression for spilling batches
+    spill_expr: Vec<PhysicalSortExpr>,
+
+    /// Schema for spilling batches
+    spill_schema: SchemaRef,
+
+    /// true when streaming merge is in progress
+    is_stream_merging: bool,
+
+    /// aggregate_arguments for merging spilled data
+    merging_aggregate_arguments: Vec<Vec<Arc<dyn PhysicalExpr>>>,
+
+    /// GROUP BY expressions for merging spilled data
+    merging_group_by: PhysicalGroupBy,
+}
+
 /// HashTable based Grouping Aggregator
 ///
 /// # Design Goals
@@ -120,6 +151,57 @@ use super::AggregateExec;
 /// hash table).
 ///
 /// [`group_values`]: Self::group_values
+///
+/// # Spilling
+///
+/// The sizes of group values and accumulators can become large. Before that causes out of memory,
+/// this hash aggregator outputs partial states early for partial aggregation or spills to local
+/// disk using Arrow IPC format for final aggregation. For every input [`RecordBatch`], the memory
+/// manager checks whether the new input size meets the memory configuration. If not, outputting or
+/// spilling happens. For outputting, the final aggregation takes care of re-grouping. For spilling,
+/// later stream-merge sort on reading back the spilled data does re-grouping. Note the rows cannot
+/// be grouped once spilled onto disk, the read back data needs to be re-grouped again. In addition,
+/// re-grouping may cause out of memory again. Thus, re-grouping has to be a sort based aggregation.
+///
+/// ```text
+/// Partial Aggregation [batch_size = 2] (max memory = 3 rows)
+///
+///  INPUTS        PARTIALLY AGGREGATED (UPDATE BATCH)   OUTPUTS
+/// ┌─────────┐    ┌─────────────────┐                  ┌─────────────────┐
+/// │ a │ b   │    │ a │    AVG(b)   │                  │ a │    AVG(b)   │
+/// │---│-----│    │   │[count]│[sum]│                  │   │[count]│[sum]│
+/// │ 3 │ 3.0 │ ─▶ │---│-------│-----│                  │---│-------│-----│
+/// │ 2 │ 2.0 │    │ 2 │ 1     │ 2.0 │ ─▶ early emit ─▶ │ 2 │ 1     │ 2.0 │
+/// └─────────┘    │ 3 │ 2     │ 7.0 │               │  │ 3 │ 2     │ 7.0 │
+/// ┌─────────┐ ─▶ │ 4 │ 1     │ 8.0 │               │  └─────────────────┘
+/// │ 3 │ 4.0 │    └─────────────────┘               └▶ ┌─────────────────┐
+/// │ 4 │ 8.0 │    ┌─────────────────┐                  │ 4 │ 1     │ 8.0 │
+/// └─────────┘    │ a │    AVG(b)   │               ┌▶ │ 1 │ 1     │ 1.0 │
+/// ┌─────────┐    │---│-------│-----│               │  └─────────────────┘
+/// │ 1 │ 1.0 │ ─▶ │ 1 │ 1     │ 1.0 │ ─▶ early emit ─▶ ┌─────────────────┐
+/// │ 3 │ 2.0 │    │ 3 │ 1     │ 2.0 │                  │ 3 │ 1     │ 2.0 │
+/// └─────────┘    └─────────────────┘                  └─────────────────┘
+///
+///
+/// Final Aggregation [batch_size = 2] (max memory = 3 rows)
+///
+/// PARTIALLY INPUTS       FINAL AGGREGATION (MERGE BATCH)       RE-GROUPED (SORTED)
+/// ┌─────────────────┐    [keep using the partial schema]       [Real final aggregation
+/// │ a │    AVG(b)   │    ┌─────────────────┐                    output]
+/// │   │[count]│[sum]│    │ a │    AVG(b)   │                   ┌────────────┐
+/// │---│-------│-----│ ─▶ │   │[count]│[sum]│                   │ a │ AVG(b) │
+/// │ 3 │ 3     │ 3.0 │    │---│-------│-----│ ─▶ spill ─┐       │---│--------│
+/// │ 2 │ 2     │ 1.0 │    │ 2 │ 2     │ 1.0 │           │       │ 1 │    4.0 │
+/// └─────────────────┘    │ 3 │ 4     │ 8.0 │           ▼       │ 2 │    1.0 │
+/// ┌─────────────────┐ ─▶ │ 4 │ 1     │ 7.0 │     Streaming  ─▶ └────────────┘
+/// │ 3 │ 1     │ 5.0 │    └─────────────────┘     merge sort ─▶ ┌────────────┐
+/// │ 4 │ 1     │ 7.0 │    ┌─────────────────┐            ▲      │ a │ AVG(b) │
+/// └─────────────────┘    │ a │    AVG(b)   │            │      │---│--------│
+/// ┌─────────────────┐    │---│-------│-----│ ─▶ memory ─┘      │ 3 │    2.0 │
+/// │ 1 │ 2     │ 8.0 │ ─▶ │ 1 │ 2     │ 8.0 │                   │ 4 │    7.0 │
+/// │ 2 │ 2     │ 3.0 │    │ 2 │ 2     │ 3.0 │                   └────────────┘
+/// └─────────────────┘    └─────────────────┘
+/// ```
 pub(crate) struct GroupedHashAggregateStream {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
@@ -178,6 +260,12 @@ pub(crate) struct GroupedHashAggregateStream {
 
     /// Have we seen the end of the input
     input_done: bool,
+
+    /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
+    runtime: Arc<RuntimeEnv>,
+
+    /// The spill state object
+    spill_state: SpillState,
 }
 
 impl GroupedHashAggregateStream {
@@ -207,6 +295,12 @@ impl GroupedHashAggregateStream {
             &agg.mode,
             agg_group_by.expr.len(),
         )?;
+        // arguments for aggregating spilled data is the same as the one for final aggregation
+        let merging_aggregate_arguments = aggregates::aggregate_expressions(
+            &agg.aggr_expr,
+            &AggregateMode::Final,
+            agg_group_by.expr.len(),
+        )?;
 
         let filter_expressions = match agg.mode {
             AggregateMode::Partial
@@ -224,6 +318,14 @@ impl GroupedHashAggregateStream {
             .collect::<Result<_>>()?;
 
         let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
+        let spill_expr = group_schema
+            .fields
+            .into_iter()
+            .map(|field| PhysicalSortExpr {
+                expr: col(field.name(), &group_schema).unwrap(),
+                options: SortOptions::default(),
+            })
+            .collect();
 
         let name = format!("GroupedHashAggregateStream[{partition}]");
         let reservation = MemoryConsumer::new(name).register(context.memory_pool());
@@ -243,6 +345,15 @@ impl GroupedHashAggregateStream {
 
         let exec_state = ExecutionState::ReadingInput;
 
+        let spill_state = SpillState {
+            spills: vec![],
+            spill_expr,
+            spill_schema: agg_schema.clone(),
+            is_stream_merging: false,
+            merging_aggregate_arguments,
+            merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
+        };
+
         Ok(GroupedHashAggregateStream {
             schema: agg_schema,
             input,
@@ -259,6 +370,8 @@ impl GroupedHashAggregateStream {
             batch_size,
             group_ordering,
             input_done: false,
+            runtime: context.runtime_env(),
+            spill_state,
         })
     }
 }
@@ -310,6 +423,9 @@ impl Stream for GroupedHashAggregateStream {
                         // new batch to aggregate
                         Some(Ok(batch)) => {
                             let timer = elapsed_compute.timer();
+                            // Make sure we have enough capacity for `batch`, otherwise spill
+                            extract_ok!(self.spill_previous_if_necessary(&batch));
+
                             // Do the grouping
                             extract_ok!(self.group_aggregate_batch(batch));
 
@@ -318,9 +434,12 @@ impl Stream for GroupedHashAggregateStream {
                             assert!(!self.input_done);
 
                             if let Some(to_emit) = self.group_ordering.emit_to() {
-                                let batch = extract_ok!(self.emit(to_emit));
+                                let batch = extract_ok!(self.emit(to_emit, false));
                                 self.exec_state = ExecutionState::ProducingOutput(batch);
                             }
+
+                            extract_ok!(self.emit_early_if_necessary());
+
                             timer.done();
                         }
                         Some(Err(e)) => {
@@ -332,8 +451,14 @@ impl Stream for GroupedHashAggregateStream {
                             self.input_done = true;
                             self.group_ordering.input_done();
                             let timer = elapsed_compute.timer();
-                            let batch = extract_ok!(self.emit(EmitTo::All));
-                            self.exec_state = ExecutionState::ProducingOutput(batch);
+                            if self.spill_state.spills.is_empty() {
+                                let batch = extract_ok!(self.emit(EmitTo::All, false));
+                                self.exec_state = ExecutionState::ProducingOutput(batch);
+                            } else {
+                                // If spill files exist, stream-merge them.
+                                extract_ok!(self.update_merged_stream());
+                                self.exec_state = ExecutionState::ReadingInput;
+                            }
                             timer.done();
                         }
                     }
@@ -360,7 +485,13 @@ impl Stream for GroupedHashAggregateStream {
                     )));
                 }
 
-                ExecutionState::Done => return Poll::Ready(None),
+                ExecutionState::Done => {
+                    // release the memory reservation since sending back output batch itself needs
+                    // some memory reservation, so make some room for it.
+                    self.clear_all();
+                    let _ = self.update_memory_reservation();
+                    return Poll::Ready(None);
+                }
             }
         }
     }
@@ -376,13 +507,26 @@ impl GroupedHashAggregateStream {
     /// Perform group-by aggregation for the given [`RecordBatch`].
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
         // Evaluate the grouping expressions
-        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        let group_by_values = if self.spill_state.is_stream_merging {
+            evaluate_group_by(&self.spill_state.merging_group_by, &batch)?
+        } else {
+            evaluate_group_by(&self.group_by, &batch)?
+        };
 
         // Evaluate the aggregation expressions.
-        let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
+        let input_values = if self.spill_state.is_stream_merging {
+            evaluate_many(&self.spill_state.merging_aggregate_arguments, &batch)?
+        } else {
+            evaluate_many(&self.aggregate_arguments, &batch)?
+        };
 
         // Evaluate the filter expressions, if any, against the inputs
-        let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;
+        let filter_values = if self.spill_state.is_stream_merging {
+            let filter_expressions = vec![None; self.accumulators.len()];
+            evaluate_optional(&filter_expressions, &batch)?
+        } else {
+            evaluate_optional(&self.filter_expressions, &batch)?
+        };
 
         for group_values in &group_by_values {
             // calculate the group indices for each input row
@@ -416,7 +560,9 @@ impl GroupedHashAggregateStream {
                 match self.mode {
                     AggregateMode::Partial
                     | AggregateMode::Single
-                    | AggregateMode::SinglePartitioned => {
+                    | AggregateMode::SinglePartitioned
+                        if !self.spill_state.is_stream_merging =>
+                    {
                         acc.update_batch(
                             values,
                             group_indices,
@@ -424,7 +570,7 @@ impl GroupedHashAggregateStream {
                             total_num_groups,
                         )?;
                     }
-                    AggregateMode::FinalPartitioned | AggregateMode::Final => {
+                    _ => {
                         // if aggregation is over intermediate states,
                         // use merge
                         acc.merge_batch(
@@ -438,7 +584,16 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        self.update_memory_reservation()
+        match self.update_memory_reservation() {
+            // Here we can ignore `insufficient_capacity_err` because we will spill later,
+            // but at least one batch should fit in the memory
+            Err(DataFusionError::ResourcesExhausted(_))
+                if self.group_values.len() >= self.batch_size =>
+            {
+                Ok(())
+            }
+            other => other,
+        }
     }
 
     fn update_memory_reservation(&mut self) -> Result<()> {
@@ -452,9 +607,14 @@ impl GroupedHashAggregateStream {
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit(&mut self, emit_to: EmitTo) -> Result<RecordBatch> {
+    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> 
{
+        let schema = if spilling {
+            self.spill_state.spill_schema.clone()
+        } else {
+            self.schema()
+        };
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(self.schema()));
+            return Ok(RecordBatch::new_empty(schema));
         }
 
         let mut output = self.group_values.emit(emit_to)?;
@@ -466,6 +626,11 @@ impl GroupedHashAggregateStream {
         for acc in self.accumulators.iter_mut() {
             match self.mode {
                 AggregateMode::Partial => output.extend(acc.state(emit_to)?),
+                _ if spilling => {
+                    // If spilling, output partial state because the spilled 
data will be
+                    // merged and re-evaluated later.
+                    output.extend(acc.state(emit_to)?)
+                }
                 AggregateMode::Final
                 | AggregateMode::FinalPartitioned
                 | AggregateMode::Single
@@ -473,8 +638,110 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        self.update_memory_reservation()?;
-        let batch = RecordBatch::try_new(self.schema(), output)?;
+        // emit reduces the memory usage. Ignore Err from 
update_memory_reservation. Even if it is
+        // over the target memory size after emission, we can emit again 
rather than returning Err.
+        let _ = self.update_memory_reservation();
+        let batch = RecordBatch::try_new(schema, output)?;
         Ok(batch)
     }
+
+    /// Optimistically, [`Self::group_aggregate_batch`] is allowed to exceed the 
memory target slightly
+    /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to 
disk and clear the
+    /// memory. Currently only [`GroupOrdering::None`] is supported for 
spilling.
+    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> 
Result<()> {
+        // TODO: support group_ordering for spilling
+        if self.group_values.len() > 0
+            && batch.num_rows() > 0
+            && matches!(self.group_ordering, GroupOrdering::None)
+            && !matches!(self.mode, AggregateMode::Partial)
+            && !self.spill_state.is_stream_merging
+            && self.update_memory_reservation().is_err()
+        {
+            // Use input batch (Partial mode) schema for spilling because
+            // the spilled data will be merged and re-evaluated later.
+            self.spill_state.spill_schema = batch.schema();
+            self.spill()?;
+            self.clear_shrink(batch);
+        }
+        Ok(())
+    }
+
+    /// Emit all rows, sort them, and store them on disk.
+    fn spill(&mut self) -> Result<()> {
+        let emit = self.emit(EmitTo::All, true)?;
+        let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
+        let spillfile = 
self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
+        let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?;
+        // TODO: slice large `sorted` and write to multiple files in parallel
+        writer.write(&sorted)?;
+        writer.finish()?;
+        self.spill_state.spills.push(spillfile);
+        Ok(())
+    }
+
+    /// Clear memory and shrink capacities to the size of the batch.
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        self.group_values.clear_shrink(batch);
+        self.current_group_indices.clear();
+        self.current_group_indices.shrink_to(batch.num_rows());
+    }
+
+    /// Clear memory and shrink capacities to zero.
+    fn clear_all(&mut self) {
+        let s = self.schema();
+        self.clear_shrink(&RecordBatch::new_empty(s));
+    }
+
+    /// Emit if the used memory exceeds the target for partial aggregation.
+    /// Currently only [`GroupOrdering::None`] is supported for early emitting.
+    /// TODO: support group_ordering for early emitting
+    fn emit_early_if_necessary(&mut self) -> Result<()> {
+        if self.group_values.len() >= self.batch_size
+            && matches!(self.group_ordering, GroupOrdering::None)
+            && matches!(self.mode, AggregateMode::Partial)
+            && self.update_memory_reservation().is_err()
+        {
+            let n = self.group_values.len() / self.batch_size * 
self.batch_size;
+            let batch = self.emit(EmitTo::First(n), false)?;
+            self.exec_state = ExecutionState::ProducingOutput(batch);
+        }
+        Ok(())
+    }
+
+    /// At this point, all the inputs are read and there are some spills.
+    /// Emit the remaining rows and create a batch.
+    /// Conduct a streaming merge sort between the batch and spilled data. 
Since the stream is fully
+    /// sorted, set `self.group_ordering` to Full, then later we can read with 
[`EmitTo::First`].
+    fn update_merged_stream(&mut self) -> Result<()> {
+        let batch = self.emit(EmitTo::All, true)?;
+        // clear up memory for streaming_merge
+        self.clear_all();
+        self.update_memory_reservation()?;
+        let mut streams: Vec<SendableRecordBatchStream> = vec![];
+        let expr = self.spill_state.spill_expr.clone();
+        let schema = batch.schema();
+        streams.push(Box::pin(RecordBatchStreamAdapter::new(
+            schema.clone(),
+            futures::stream::once(futures::future::lazy(move |_| {
+                sort_batch(&batch, &expr, None)
+            })),
+        )));
+        for spill in self.spill_state.spills.drain(..) {
+            let stream = read_spill_as_stream(spill, schema.clone())?;
+            streams.push(stream);
+        }
+        self.spill_state.is_stream_merging = true;
+        self.input = streaming_merge(
+            streams,
+            schema,
+            &self.spill_state.spill_expr,
+            self.baseline_metrics.clone(),
+            self.batch_size,
+            None,
+            self.reservation.new_empty(),
+        )?;
+        self.input_done = false;
+        self.group_ordering = GroupOrdering::Full(GroupOrderingFull::new());
+        Ok(())
+    }
 }
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs 
b/datafusion/core/src/physical_plan/sorts/sort.rs
index 17b94d51c5..92fb45142e 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -574,7 +574,7 @@ impl Debug for ExternalSorter {
     }
 }
 
-fn sort_batch(
+pub(crate) fn sort_batch(
     batch: &RecordBatch,
     expressions: &[PhysicalSortExpr],
     fetch: Option<usize>,
@@ -608,7 +608,7 @@ async fn spill_sorted_batches(
     }
 }
 
-fn read_spill_as_stream(
+pub(crate) fn read_spill_as_stream(
     path: RefCountedTempFile,
     schema: SchemaRef,
 ) -> Result<SendableRecordBatchStream> {
diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs 
b/datafusion/physical-expr/src/aggregate/first_last.rs
index 7e8930ce2a..02bb466d44 100644
--- a/datafusion/physical-expr/src/aggregate/first_last.rs
+++ b/datafusion/physical-expr/src/aggregate/first_last.rs
@@ -165,6 +165,8 @@ struct FirstValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
+    // Whether merge_batch() is called before
+    is_merge_called: bool,
 }
 
 impl FirstValueAccumulator {
@@ -183,6 +185,7 @@ impl FirstValueAccumulator {
             is_set: false,
             orderings,
             ordering_req,
+            is_merge_called: false,
         })
     }
 
@@ -198,7 +201,9 @@ impl Accumulator for FirstValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let mut result = vec![self.first.clone()];
         result.extend(self.orderings.iter().cloned());
-        result.push(ScalarValue::Boolean(Some(self.is_set)));
+        if !self.is_merge_called {
+            result.push(ScalarValue::Boolean(Some(self.is_set)));
+        }
         Ok(result)
     }
 
@@ -213,6 +218,7 @@ impl Accumulator for FirstValueAccumulator {
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.is_merge_called = true;
         // FIRST_VALUE(first1, first2, first3, ...)
         // last index contains is_set flag.
         let is_set_idx = states.len() - 1;
@@ -384,6 +390,8 @@ struct LastValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
+    // Whether merge_batch() is called before
+    is_merge_called: bool,
 }
 
 impl LastValueAccumulator {
@@ -402,6 +410,7 @@ impl LastValueAccumulator {
             is_set: false,
             orderings,
             ordering_req,
+            is_merge_called: false,
         })
     }
 
@@ -417,7 +426,9 @@ impl Accumulator for LastValueAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let mut result = vec![self.last.clone()];
         result.extend(self.orderings.clone());
-        result.push(ScalarValue::Boolean(Some(self.is_set)));
+        if !self.is_merge_called {
+            result.push(ScalarValue::Boolean(Some(self.is_set)));
+        }
         Ok(result)
     }
 
@@ -431,6 +442,7 @@ impl Accumulator for LastValueAccumulator {
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.is_merge_called = true;
         // LAST_VALUE(last1, last2, last3, ...)
         // last index contains is_set flag.
         let is_set_idx = states.len() - 1;


Reply via email to