alamb commented on code in PR #7400:
URL: https://github.com/apache/arrow-datafusion/pull/7400#discussion_r1324946837


##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -466,15 +659,119 @@ impl GroupedHashAggregateStream {
         for acc in self.accumulators.iter_mut() {
             match self.mode {
                 AggregateMode::Partial => output.extend(acc.state(emit_to)?),
+                _ if spilling => {
+                    // If spilling, output partial state because the spilled data will be
+                    // merged and re-evaluated later.
+                    output.extend(acc.state(emit_to)?)
+                }
                 AggregateMode::Final
                 | AggregateMode::FinalPartitioned
                 | AggregateMode::Single
                 | AggregateMode::SinglePartitioned => output.push(acc.evaluate(emit_to)?),
             }
         }
 
-        self.update_memory_reservation()?;
-        let batch = RecordBatch::try_new(self.schema(), output)?;
+        // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
+        // over the target memory size after emission, we can emit again rather than returning Err.
+        let _ = self.update_memory_reservation();
+        let batch = RecordBatch::try_new(schema, output)?;
         Ok(batch)
     }
+
+    /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
+    /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
+    /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
+    fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
+        // TODO: support group_ordering for spilling
+        if self.group_values.len() > 0
+            && batch.num_rows() > 0
+            && matches!(self.group_ordering, GroupOrdering::None)
+            && !matches!(self.mode, AggregateMode::Partial)
+            && self.update_memory_reservation().is_err()
+        {
+            // Use input batch (Partial mode) schema for spilling because
+            // the spilled data will be merged and re-evaluated later.
+            self.spill_state.spill_schema = batch.schema();
+            self.spill()?;
+            self.clear_shrink(batch);
+        }
+        Ok(())
+    }
+
+    /// Emit all rows, sort them, and store them on disk.
+    fn spill(&mut self) -> Result<()> {
+        let emit = self.emit(EmitTo::All, true)?;
+        let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
+        let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
+        let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?;
+        // TODO: slice large `sorted` and write to multiple files in parallel
+        writer.write(&sorted)?;
+        writer.finish()?;
+        self.spill_state.spills.push(spillfile);
+        Ok(())
+    }
+
+    /// Clear memory and shirk capacities to the size of the batch.
+    fn clear_shrink(&mut self, batch: &RecordBatch) {
+        self.group_values.clear_shrink(batch);
+        self.current_group_indices.clear();
+        self.current_group_indices.shrink_to(batch.num_rows());
+    }
+
+    /// Clear memory and shirk capacities to zero.
+    fn clear_all(&mut self) {
+        let s = self.schema();
+        self.clear_shrink(&RecordBatch::new_empty(s));
+    }
+
+    /// Emit if the used memory exceeds the target for partial aggregation.
+    /// Currently only [`GroupOrdering::None`] is supported for spilling.
+    /// TODO: support group_ordering for spilling
+    fn emit_early_if_necessary(&mut self) -> Result<()> {
+        if self.group_values.len() >= self.batch_size
+            && matches!(self.group_ordering, GroupOrdering::None)
+            && matches!(self.mode, AggregateMode::Partial)
+            && self.update_memory_reservation().is_err()
+        {
+            let n = self.group_values.len() / self.batch_size * self.batch_size;
+            let batch = self.emit(EmitTo::First(n), false)?;
+            self.exec_state = ExecutionState::ProducingOutput(batch);
+        }
+        Ok(())
+    }
+
+    /// At this point, all the inputs are read and there are some spills.
+    /// Emit the remaining rows and create a batch.
+    /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
+    /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
+    fn update_merged_stream(&mut self) -> Result<()> {

Review Comment:
   At this point, rather than tracking the `merged-stream` specially, I think we could simply replace `self.input` and set the overall group-by state back to `ReadingInput` to read the data from the sorted input stream (after updating `self.group_ordering`, the aggregate expressions, etc.).
   
   That might simplify the code.



##########
datafusion/physical-expr/src/aggregate/first_last.rs:
##########
@@ -165,6 +165,8 @@ struct FirstValueAccumulator {
     orderings: Vec<ScalarValue>,
     // Stores the applicable ordering requirement.
     ordering_req: LexOrdering,
+    // Whether merge_batch() is called before

Review Comment:
   I don't understand the need for this flag.
   
   
   I removed it with this diff:
   
   ```diff
   (arrow_dev) alamb@MacBook-Pro-8 arrow-datafusion2 % git diff | cat 
   git diff | cat 
   diff --git a/datafusion/physical-expr/src/aggregate/first_last.rs b/datafusion/physical-expr/src/aggregate/first_last.rs
   index 02bb466d44..7e8930ce2a 100644
   --- a/datafusion/physical-expr/src/aggregate/first_last.rs
   +++ b/datafusion/physical-expr/src/aggregate/first_last.rs
   @@ -165,8 +165,6 @@ struct FirstValueAccumulator {
        orderings: Vec<ScalarValue>,
        // Stores the applicable ordering requirement.
        ordering_req: LexOrdering,
   -    // Whether merge_batch() is called before
   -    is_merge_called: bool,
    }
    
    impl FirstValueAccumulator {
   @@ -185,7 +183,6 @@ impl FirstValueAccumulator {
                is_set: false,
                orderings,
                ordering_req,
   -            is_merge_called: false,
            })
        }
    
   @@ -201,9 +198,7 @@ impl Accumulator for FirstValueAccumulator {
        fn state(&self) -> Result<Vec<ScalarValue>> {
            let mut result = vec![self.first.clone()];
            result.extend(self.orderings.iter().cloned());
   -        if !self.is_merge_called {
   -            result.push(ScalarValue::Boolean(Some(self.is_set)));
   -        }
   +        result.push(ScalarValue::Boolean(Some(self.is_set)));
            Ok(result)
        }
    
   @@ -218,7 +213,6 @@ impl Accumulator for FirstValueAccumulator {
        }
    
        fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
   -        self.is_merge_called = true;
            // FIRST_VALUE(first1, first2, first3, ...)
            // last index contains is_set flag.
            let is_set_idx = states.len() - 1;
   @@ -390,8 +384,6 @@ struct LastValueAccumulator {
        orderings: Vec<ScalarValue>,
        // Stores the applicable ordering requirement.
        ordering_req: LexOrdering,
   -    // Whether merge_batch() is called before
   -    is_merge_called: bool,
    }
    
    impl LastValueAccumulator {
   @@ -410,7 +402,6 @@ impl LastValueAccumulator {
                is_set: false,
                orderings,
                ordering_req,
   -            is_merge_called: false,
            })
        }
    
   @@ -426,9 +417,7 @@ impl Accumulator for LastValueAccumulator {
        fn state(&self) -> Result<Vec<ScalarValue>> {
            let mut result = vec![self.last.clone()];
            result.extend(self.orderings.clone());
   -        if !self.is_merge_called {
   -            result.push(ScalarValue::Boolean(Some(self.is_set)));
   -        }
   +        result.push(ScalarValue::Boolean(Some(self.is_set)));
            Ok(result)
        }
    
   @@ -442,7 +431,6 @@ impl Accumulator for LastValueAccumulator {
        }
    
        fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
   -        self.is_merge_called = true;
            // LAST_VALUE(last1, last2, last3, ...)
            // last index contains is_set flag.
            let is_set_idx = states.len() - 1;
   ```
   
   And the tests all still seem to pass.



##########
datafusion/core/src/physical_plan/aggregates/group_values/mod.rs:
##########
@@ -42,6 +43,9 @@ pub trait GroupValues: Send {
 
     /// Emits the group values
     fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+
+    /// clear the contents and shrink the capacity

Review Comment:
   ```suggestion
       /// clear the contents and shrink the capacity to free up memory usage.
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to