This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new bbf9d30e0 feat: Add batch coalescing ability to shuffle reader exec (#1380)
bbf9d30e0 is described below

commit bbf9d30e0eb265a700863778a3be4a1f74f6e68d
Author: Daniel Tu <[email protected]>
AuthorDate: Wed Jan 21 12:02:45 2026 -0800

    feat: Add batch coalescing ability to shuffle reader exec (#1380)
    
    * impl
    
    * fix format and simplify new
---
 .../core/src/execution_plans/shuffle_reader.rs     | 282 ++++++++++++++++++++-
 1 file changed, 276 insertions(+), 6 deletions(-)

diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index 6776d3597..7de252c9c 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -18,6 +18,7 @@
 use async_trait::async_trait;
 use datafusion::arrow::ipc::reader::StreamReader;
 use datafusion::common::stats::Precision;
+use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, 
PushBatchStatus};
 use std::any::Any;
 use std::collections::HashMap;
 use std::fmt::Debug;
@@ -41,12 +42,14 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::common::runtime::SpawnedTask;
 
 use datafusion::error::{DataFusionError, Result};
-use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use datafusion::physical_plan::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
 use datafusion::physical_plan::{
     ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, 
Partitioning,
     PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
-use futures::{Stream, StreamExt, TryStreamExt};
+use futures::{Stream, StreamExt, TryStreamExt, ready};
 
 use crate::error::BallistaError;
 use datafusion::execution::context::TaskContext;
@@ -165,6 +168,7 @@ impl ExecutionPlan for ShuffleReaderExec {
         let max_message_size = config.ballista_grpc_client_max_message_size();
         let force_remote_read = 
config.ballista_shuffle_reader_force_remote_read();
         let prefer_flight = 
config.ballista_shuffle_reader_remote_prefer_flight();
+        let batch_size = config.batch_size();
 
         if force_remote_read {
             debug!(
@@ -200,11 +204,18 @@ impl ExecutionPlan for ShuffleReaderExec {
             prefer_flight,
         );
 
-        let result = RecordBatchStreamAdapter::new(
-            Arc::new(self.schema.as_ref().clone()),
+        let input_stream = Box::pin(RecordBatchStreamAdapter::new(
+            self.schema.clone(),
             response_receiver.try_flatten(),
-        );
-        Ok(Box::pin(result))
+        ));
+
+        Ok(Box::pin(CoalescedShuffleReaderStream::new(
+            input_stream,
+            batch_size,
+            None,
+            &self.metrics,
+            partition,
+        )))
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -594,6 +605,96 @@ async fn fetch_partition_object_store(
     ))
 }
 
+/// Stream adapter that merges the batches read by the shuffle reader into
+/// larger batches (via `LimitedBatchCoalescer`) before passing them downstream.
+struct CoalescedShuffleReaderStream {
+    /// Upstream stream of record batches fetched for this partition.
+    input: SendableRecordBatchStream,
+    /// Schema reported by this stream; captured from `input` at construction.
+    schema: SchemaRef,
+    /// Set once upstream is exhausted or the row limit has been reached; after
+    /// that, `input` is no longer polled and only buffered batches are emitted.
+    completed: bool,
+    /// Buffers incoming rows and hands out completed, coalesced batches.
+    coalescer: LimitedBatchCoalescer,
+    /// Output-row and elapsed-compute metrics for this partition.
+    baseline_metrics: BaselineMetrics,
+}
+
+impl CoalescedShuffleReaderStream {
+    /// Wraps `input` in a coalescing stream.
+    ///
+    /// * `batch_size` - target number of rows per emitted batch.
+    /// * `limit` - optional row limit forwarded to `LimitedBatchCoalescer`.
+    /// * `metrics` / `partition` - where this stream's `BaselineMetrics` are registered.
+    pub fn new(
+        input: SendableRecordBatchStream,
+        batch_size: usize,
+        limit: Option<usize>,
+        metrics: &ExecutionPlanMetricsSet,
+        partition: usize,
+    ) -> Self {
+        // The output schema is exactly the input schema; coalescing never changes columns.
+        let schema = input.schema();
+        let coalescer = LimitedBatchCoalescer::new(schema.clone(), batch_size, limit);
+        let baseline_metrics = BaselineMetrics::new(metrics, partition);
+
+        Self {
+            input,
+            schema,
+            completed: false,
+            coalescer,
+            baseline_metrics,
+        }
+    }
+}
+
+impl Stream for CoalescedShuffleReaderStream {
+    type Item = Result<RecordBatch>;
+
+    /// Polls upstream until a coalesced batch is ready, the input is exhausted,
+    /// or an error occurs.
+    ///
+    /// Every ready result is passed through `BaselineMetrics::record_poll`, which
+    /// records output rows and marks the stream as done on completion, so the
+    /// end time is captured when the stream finishes rather than only on drop.
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        // NOTE: this timer also covers the time spent inside the upstream `poll_next`.
+        let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
+        let poll = loop {
+            // A completed batch is already buffered: emit it before pulling more input.
+            if let Some(batch) = self.coalescer.next_completed_batch() {
+                break Poll::Ready(Some(Ok(batch)));
+            }
+
+            // Upstream finished (or the limit was hit) and the buffer is drained.
+            if self.completed {
+                break Poll::Ready(None);
+            }
+
+            // `ready!` returns `Poll::Pending` straight to the caller; a pending
+            // poll has nothing to record in the metrics.
+            match ready!(self.input.poll_next_unpin(cx)) {
+                // Upstream is exhausted: flush the remaining buffered rows.
+                None => {
+                    self.completed = true;
+                    if let Err(e) = self.coalescer.finish() {
+                        break Poll::Ready(Some(Err(e)));
+                    }
+                }
+                // Empty batches carry no rows; skip them.
+                Some(Ok(batch)) if batch.num_rows() == 0 => {}
+                Some(Ok(batch)) => match self.coalescer.push_batch(batch) {
+                    // Buffered; loop to check for a completed batch or pull more input.
+                    Ok(PushBatchStatus::Continue) => {}
+                    // Limit reached: stop reading upstream and flush the buffer.
+                    Ok(PushBatchStatus::LimitReached) => {
+                        self.completed = true;
+                        if let Err(e) = self.coalescer.finish() {
+                            break Poll::Ready(Some(Err(e)));
+                        }
+                    }
+                    Err(e) => break Poll::Ready(Some(Err(e))),
+                },
+                Some(Err(e)) => break Poll::Ready(Some(Err(e))),
+            }
+        };
+
+        self.baseline_metrics.record_poll(poll)
+    }
+}
+
+impl RecordBatchStream for CoalescedShuffleReaderStream {
+    /// Returns the schema captured from the upstream input when the stream was built.
+    fn schema(&self) -> SchemaRef {
+        SchemaRef::clone(&self.schema)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1052,10 +1153,179 @@ mod tests {
         .unwrap()
     }
 
+    /// Builds a batch matching `create_test_schema()` with `rows` rows:
+    /// `number` holds 0, 1, ..., rows-1 and `str` holds "s0", "s1", ....
+    fn create_custom_test_batch(rows: usize) -> RecordBatch {
+        let numbers = UInt32Array::from((0..rows as u32).collect::<Vec<u32>>());
+        // Filler values only; the tests care about row counts, not content.
+        let strings = StringArray::from(
+            (0..rows).map(|i| format!("s{i}")).collect::<Vec<String>>(),
+        );
+
+        RecordBatch::try_new(
+            create_test_schema(),
+            vec![Arc::new(numbers), Arc::new(strings)],
+        )
+        .unwrap()
+    }
+
+    /// Schema shared by these tests: a nullable `number` (UInt32) column and a
+    /// nullable `str` (Utf8) column.
     fn create_test_schema() -> SchemaRef {
         Arc::new(Schema::new(vec![
             Field::new("number", DataType::UInt32, true),
             Field::new("str", DataType::Utf8, true),
         ]))
     }
+
+    use datafusion::physical_plan::memory::MemoryStream;
+
+    #[tokio::test]
+    async fn test_coalesce_stream_logic() -> Result<()> {
+        // Upstream: ten batches of 3 rows each, 30 rows in total.
+        let schema = create_test_schema();
+        let batches = vec![create_test_batch(); 10];
+        let upstream: SendableRecordBatchStream =
+            Box::pin(MemoryStream::try_new(batches, schema, None)?);
+
+        // Coalesce toward 10-row output batches.
+        let stream = CoalescedShuffleReaderStream::new(
+            upstream,
+            10,
+            None,
+            &ExecutionPlanMetricsSet::new(),
+            0,
+        );
+        let output = common::collect(Box::pin(stream)).await?;
+
+        // No rows are lost...
+        let total_rows: usize = output.iter().map(RecordBatch::num_rows).sum();
+        assert_eq!(total_rows, 30);
+
+        // ...and the ten small inputs collapse into three full 10-row batches.
+        assert_eq!(output.len(), 3);
+        for batch in &output {
+            assert_eq!(batch.num_rows(), 10);
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_coalesce_stream_remainder_flush() -> Result<()> {
+        let schema = create_test_schema();
+        // Create 10 small batch, each with 3 rows. Total 30 rows.
+        let small_batch = create_test_batch();
+        let batches = vec![small_batch.clone(); 10];
+
+        let input_stream = MemoryStream::try_new(batches, schema.clone(), 
None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // Target set to 100 rows.
+        // Because 30 < 100, it can never be filled. Must depend on the 
`finish()` mechanism to flush out these 30 rows at the end of the stream.
+        let target_batch_size = 100;
+
+        let coalesced_stream = CoalescedShuffleReaderStream::new(
+            input_stream,
+            target_batch_size,
+            None,
+            &ExecutionPlanMetricsSet::new(),
+            0,
+        );
+
+        let output_batches = 
common::collect(Box::pin(coalesced_stream)).await?;
+
+        // Assertions
+        assert_eq!(output_batches.len(), 1); // Should only have 1 batch
+        assert_eq!(output_batches[0].num_rows(), 30); // Should contain all 30 
rows
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_coalesce_stream_large_batch() -> Result<()> {
+        let schema = create_test_schema();
+
+        // 1. Ten identical input batches of 20 rows each: 200 rows in total.
+        let big_batch = create_custom_test_batch(20);
+        let batches = vec![big_batch; 10];
+
+        let input_stream = MemoryStream::try_new(batches, schema, None)?;
+        let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+        // 2. Target a batch size smaller than every input batch.
+        let target_batch_size = 10;
+
+        let coalesced_stream = CoalescedShuffleReaderStream::new(
+            input_stream,
+            target_batch_size,
+            None,
+            &ExecutionPlanMetricsSet::new(),
+            0,
+        );
+
+        let output_batches = common::collect(Box::pin(coalesced_stream)).await?;
+
+        // 3. Oversized inputs are expected to pass through untouched: neither split
+        //    into 10-row pieces nor merged together. (Per the coalescer's behaviour,
+        //    batches larger than `target_batch_size / 2` are not split.)
+        //    Check every output batch, not only the first, and that no rows are lost.
+        let total_rows: usize = output_batches.iter().map(RecordBatch::num_rows).sum();
+        assert_eq!(total_rows, 200);
+        assert_eq!(output_batches.len(), 10);
+        assert!(output_batches.iter().all(|batch| batch.num_rows() == 20));
+
+        Ok(())
+    }
+
+    use futures::stream;
+
+    #[tokio::test]
+    async fn test_coalesce_stream_error_propagation() -> Result<()> {
+        let schema = create_test_schema();
+        let small_batch = create_test_batch(); // 3 rows
+
+        // 1. Upstream items: one good batch followed by an error.
+        let batches = vec![
+            Ok(small_batch),
+            Err(DataFusionError::Execution(
+                "Network connection failed".to_string(),
+            )),
+        ];
+
+        // 2. Wrap those items in a record batch stream.
+        let input_stream =
+            Box::pin(RecordBatchStreamAdapter::new(schema, stream::iter(batches)));
+
+        // 3. Coalesce toward 10-row batches. The 3 buffered rows never fill a
+        //    batch, so the upstream error is the first item the consumer sees.
+        let target_batch_size = 10;
+
+        let coalesced_stream = CoalescedShuffleReaderStream::new(
+            input_stream,
+            target_batch_size,
+            None,
+            &ExecutionPlanMetricsSet::new(),
+            0,
+        );
+
+        // 4. Drain the stream.
+        let result = common::collect(Box::pin(coalesced_stream)).await;
+
+        // 5. The upstream error must be surfaced with its original message.
+        let err = result.expect_err("upstream error should be propagated");
+        assert!(err.to_string().contains("Network connection failed"));
+
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to