kazuyukitanimura commented on code in PR #7400:
URL: https://github.com/apache/arrow-datafusion/pull/7400#discussion_r1326512958
##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -466,15 +625,122 @@ impl GroupedHashAggregateStream {
for acc in self.accumulators.iter_mut() {
match self.mode {
AggregateMode::Partial => output.extend(acc.state(emit_to)?),
+ _ if spilling => {
+ // If spilling, output partial state because the spilled data will be
+ // merged and re-evaluated later.
+ output.extend(acc.state(emit_to)?)
+ }
AggregateMode::Final
| AggregateMode::FinalPartitioned
| AggregateMode::Single
| AggregateMode::SinglePartitioned =>
output.push(acc.evaluate(emit_to)?),
}
}
- self.update_memory_reservation()?;
- let batch = RecordBatch::try_new(self.schema(), output)?;
+ // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
+ // over the target memory size after emission, we can emit again rather than returning Err.
+ let _ = self.update_memory_reservation();
+ let batch = RecordBatch::try_new(schema, output)?;
Ok(batch)
}
+
+ /// Optimistically, [`Self::group_aggregate_batch`] allows exceeding the memory target slightly
+ /// (~ 1 [`RecordBatch`]) for simplicity. In such cases, spill the data to disk and clear the
+ /// memory. Currently only [`GroupOrdering::None`] is supported for spilling.
+ fn spill_previous_if_necessary(&mut self, batch: &RecordBatch) -> Result<()> {
+ // TODO: support group_ordering for spilling
+ if self.group_values.len() > 0
+ && batch.num_rows() > 0
+ && matches!(self.group_ordering, GroupOrdering::None)
+ && !matches!(self.mode, AggregateMode::Partial)
+ && !self.spill_state.is_stream_merging
+ && self.update_memory_reservation().is_err()
+ {
+ // Use input batch (Partial mode) schema for spilling because
+ // the spilled data will be merged and re-evaluated later.
+ self.spill_state.spill_schema = batch.schema();
+ self.spill()?;
+ self.clear_shrink(batch);
+ }
+ Ok(())
+ }
+
+ /// Emit all rows, sort them, and store them on disk.
+ fn spill(&mut self) -> Result<()> {
+ let emit = self.emit(EmitTo::All, true)?;
+ let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
+ let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
+ let mut writer = IPCWriter::new(spillfile.path(), &emit.schema())?;
+ // TODO: slice large `sorted` and write to multiple files in parallel
+ writer.write(&sorted)?;
+ writer.finish()?;
+ self.spill_state.spills.push(spillfile);
+ Ok(())
+ }
+
+ /// Clear memory and shrink capacities to the size of the batch.
+ fn clear_shrink(&mut self, batch: &RecordBatch) {
+ self.group_values.clear_shrink(batch);
+ self.current_group_indices.clear();
+ self.current_group_indices.shrink_to(batch.num_rows());
+ }
+
+ /// Clear memory and shrink capacities to zero.
+ fn clear_all(&mut self) {
+ let s = self.schema();
+ self.clear_shrink(&RecordBatch::new_empty(s));
+ }
+
+ /// Emit if the used memory exceeds the target for partial aggregation.
+ /// Currently only [`GroupOrdering::None`] is supported for spilling.
+ /// TODO: support group_ordering for spilling
+ fn emit_early_if_necessary(&mut self) -> Result<()> {
+ if self.group_values.len() >= self.batch_size
+ && matches!(self.group_ordering, GroupOrdering::None)
+ && matches!(self.mode, AggregateMode::Partial)
+ && self.update_memory_reservation().is_err()
+ {
+ let n = self.group_values.len() / self.batch_size * self.batch_size;
+ let batch = self.emit(EmitTo::First(n), false)?;
+ self.exec_state = ExecutionState::ProducingOutput(batch);
+ }
+ Ok(())
+ }
+
+ /// At this point, all the inputs are read and there are some spills.
+ /// Emit the remaining rows and create a batch.
+ /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
+ /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
+ fn update_merged_stream(&mut self) -> Result<()> {
+ let batch = self.emit(EmitTo::All, true)?;
+ // clear up memory for streaming_merge
+ self.clear_all();
+ self.update_memory_reservation()?;
+ let mut streams: Vec<SendableRecordBatchStream> = vec![];
+ let expr = self.spill_state.spill_expr.clone();
+ let schema = batch.schema();
+ streams.push(Box::pin(RecordBatchStreamAdapter::new(
+ schema.clone(),
+ futures::stream::once(futures::future::lazy(move |_| {
+ sort_batch(&batch, &expr, None)
+ })),
+ )));
+ for spill in self.spill_state.spills.drain(..) {
+ let stream = read_spill_as_stream(spill, schema.clone())?;
+ streams.push(stream);
+ }
+ self.spill_state.is_stream_merging = true;
+ self.input = streaming_merge(
+ streams,
+ schema,
+ &self.spill_state.spill_expr,
+ self.baseline_metrics.clone(),
+ self.batch_size,
+ None,
Review Comment:
Regarding the fetch argument, if I understand your comment correctly:
my understanding is that fetch is the number of rows after which the sort can stop
and return early. Here we need all rows, so it should be `None`.
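
As a conceptual illustration only (a standalone sketch with hypothetical names, not DataFusion's actual `streaming_merge` API): an optional fetch behaves like a LIMIT on the merged, sorted output, while `None` keeps every row, which is what is needed here because all spilled partial aggregate states must survive the merge to be re-evaluated.

// Illustrative sketch: merge already-sorted inputs, optionally stopping
// after `fetch` rows. Names and types are hypothetical, not DataFusion APIs.
fn merge_sorted(inputs: Vec<Vec<i32>>, fetch: Option<usize>) -> Vec<i32> {
    // A flatten-and-sort stands in for a real k-way streaming merge.
    let mut merged: Vec<i32> = inputs.into_iter().flatten().collect();
    merged.sort_unstable();
    match fetch {
        // Some(n): the merge could stop once n rows have been produced.
        Some(n) => merged.into_iter().take(n).collect(),
        // None: keep every row, as required when the merged rows still
        // need to be re-aggregated afterwards.
        None => merged,
    }
}

fn main() {
    let in_memory = vec![1, 4, 7];
    let spilled = vec![2, 3, 9];
    // With a fetch limit, rows past the limit would be dropped.
    assert_eq!(
        merge_sorted(vec![in_memory.clone(), spilled.clone()], Some(4)),
        vec![1, 2, 3, 4]
    );
    // With `None`, the full merged stream is returned.
    assert_eq!(
        merge_sorted(vec![in_memory, spilled], None),
        vec![1, 2, 3, 4, 7, 9]
    );
}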
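
One more note on the early-emit path in the diff above: `self.group_values.len() / self.batch_size * self.batch_size` rounds the number of groups down to a whole multiple of the batch size, so only full output batches are emitted early. A tiny standalone sketch with made-up sizes (not values from the PR):

fn main() {
    // Hypothetical sizes, chosen only to show the rounding behaviour.
    let group_values_len: usize = 10_500;
    let batch_size: usize = 4_096;
    // Integer division truncates, so n is the largest multiple of
    // batch_size that does not exceed the current number of groups.
    let n = group_values_len / batch_size * batch_size;
    assert_eq!(n, 8_192);
    // The remaining groups stay in memory for a later emit.
    assert_eq!(group_values_len - n, 2_308);
}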