alamb commented on code in PR #4972:
URL: https://github.com/apache/arrow-datafusion/pull/4972#discussion_r1081176906


##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -115,6 +106,14 @@ struct GroupedHashAggregateStreamInner {
     indices: [Vec<Range<usize>>; 2],
 }
 
+#[derive(Debug)]
+/// tracks what phase the aggregation is in
+enum ExecutionState {

Review Comment:
   This used to be tracked using several multi-level `match` statements and a 
`fused` inner stream. Now it is represented explicitly in this stream.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -75,19 +75,10 @@ use hashbrown::raw::RawTable;
 /// [Compact]: datafusion_row::layout::RowType::Compact
 /// [WordAligned]: datafusion_row::layout::RowType::WordAligned
 pub(crate) struct GroupedHashAggregateStream {
-    stream: BoxStream<'static, ArrowResult<RecordBatch>>,
-    schema: SchemaRef,
-}
-
-/// Actual implementation of [`GroupedHashAggregateStream`].
-///
-/// This is wrapped into yet another struct because we need to interact with 
the async memory management subsystem

Review Comment:
   This comment about another struct for memory management is out of date, so 
I folded `GroupedHashAggregateStreamInner` directly into 
`GroupedHashAggregateStream`.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -313,222 +300,227 @@ impl RecordBatchStream for GroupedHashAggregateStream {
     }
 }
 
-/// Perform group-by aggregation for the given [`RecordBatch`].
-///
-/// If successfull, this returns the additional number of bytes that were 
allocated during this process.
-///
-/// TODO: Make this a member function of [`GroupedHashAggregateStream`]

Review Comment:
   DONE!



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -219,91 +221,76 @@ impl GroupedHashAggregateStream {
             batch_size,
             row_group_skip_position: 0,
             indices: [normal_agg_indices, row_agg_indices],
-        };
+        })
+    }
+}
 
-        let stream = futures::stream::unfold(inner, |mut this| async move {
-            let elapsed_compute = this.baseline_metrics.elapsed_compute();
+impl Stream for GroupedHashAggregateStream {
+    type Item = ArrowResult<RecordBatch>;
 
-            loop {
-                let result: ArrowResult<Option<RecordBatch>> =
-                    match this.input.next().await {
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+
+        loop {
+            match self.exec_state {
+                ExecutionState::ReadingInput => {
+                    match ready!(self.input.poll_next_unpin(cx)) {
+                        // new batch to aggregate
                         Some(Ok(batch)) => {
                             let timer = elapsed_compute.timer();
-                            let result = group_aggregate_batch(
-                                &this.mode,
-                                &this.random_state,
-                                &this.group_by,
-                                &this.normal_aggr_expr,
-                                &mut this.row_accumulators,
-                                &mut this.row_converter,
-                                this.row_aggr_layout.clone(),
-                                batch,
-                                &mut this.row_aggr_state,
-                                &this.normal_aggregate_expressions,
-                                &this.row_aggregate_expressions,
-                            );
-
+                            let result = self.group_aggregate_batch(batch);
                             timer.done();
 
                             // allocate memory
                             // This happens AFTER we actually used the memory, 
but simplifies the whole accounting and we are OK with
                             // overshooting a bit. Also this means we either 
store the whole record batch or not.
                             match result.and_then(|allocated| {
-                                
this.row_aggr_state.reservation.try_grow(allocated)
+                                
self.row_aggr_state.reservation.try_grow(allocated)
                             }) {
-                                Ok(_) => continue,
-                                Err(e) => 
Err(ArrowError::ExternalError(Box::new(e))),
+                                Ok(_) => {}
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(
+                                        ArrowError::ExternalError(Box::new(e)),
+                                    )))
+                                }
                             }
                         }
-                        Some(Err(e)) => Err(e),
+                        // inner had error, return to caller
+                        Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+                        // inner is done, producing output
                         None => {
-                            let timer = 
this.baseline_metrics.elapsed_compute().timer();
-                            let result = create_batch_from_map(
-                                &this.mode,
-                                &this.row_converter,
-                                &this.row_aggr_schema,
-                                this.batch_size,
-                                this.row_group_skip_position,
-                                &mut this.row_aggr_state,
-                                &mut this.row_accumulators,
-                                &this.schema,
-                                &this.indices,
-                            );
-
-                            timer.done();
-                            result
+                            self.exec_state = ExecutionState::ProducingOutput;
                         }
-                    };
-
-                this.row_group_skip_position += this.batch_size;
-                return match result {
-                    Ok(Some(result)) => {
-                        let batch = 
result.record_output(&this.baseline_metrics);
-                        Some((Ok(batch), this))
                     }
-                    Ok(None) => None,
-                    Err(error) => Some((Err(error), this)),
-                };
-            }
-        });
+                }
 
-        // seems like some consumers call this stream even after it returned 
`None`, so let's fuse the stream.
-        let stream = stream.fuse();

Review Comment:
   We used to fuse the stream by calling `stream.fuse()` -- but that is now 
handled explicitly via `ExecutionState::Done`.



##########
datafusion/core/src/physical_plan/aggregates/row_hash.rs:
##########
@@ -576,138 +568,131 @@ impl std::fmt::Debug for RowAggregationState {
     }
 }
 
-/// Create a RecordBatch with all group keys and accumulator' states or values.
-#[allow(clippy::too_many_arguments)]

Review Comment:
   Likewise here: this was moved from a free function to a member function on 
`GroupedHashAggregateStream`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to