This is an automated email from the ASF dual-hosted git repository.
milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new bbf9d30e0 feat: Add batch coalescing ability to shuffle reader exec
(#1380)
bbf9d30e0 is described below
commit bbf9d30e0eb265a700863778a3be4a1f74f6e68d
Author: Daniel Tu <[email protected]>
AuthorDate: Wed Jan 21 12:02:45 2026 -0800
feat: Add batch coalescing ability to shuffle reader exec (#1380)
* impl
* fix format and simplify new
---
.../core/src/execution_plans/shuffle_reader.rs | 282 ++++++++++++++++++++-
1 file changed, 276 insertions(+), 6 deletions(-)
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs
b/ballista/core/src/execution_plans/shuffle_reader.rs
index 6776d3597..7de252c9c 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -18,6 +18,7 @@
use async_trait::async_trait;
use datafusion::arrow::ipc::reader::StreamReader;
use datafusion::common::stats::Precision;
+use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer,
PushBatchStatus};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::Debug;
@@ -41,12 +42,14 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::common::runtime::SpawnedTask;
use datafusion::error::{DataFusionError, Result};
-use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use datafusion::physical_plan::metrics::{
+ BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+};
use datafusion::physical_plan::{
ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan,
Partitioning,
PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
};
-use futures::{Stream, StreamExt, TryStreamExt};
+use futures::{Stream, StreamExt, TryStreamExt, ready};
use crate::error::BallistaError;
use datafusion::execution::context::TaskContext;
@@ -165,6 +168,7 @@ impl ExecutionPlan for ShuffleReaderExec {
let max_message_size = config.ballista_grpc_client_max_message_size();
let force_remote_read =
config.ballista_shuffle_reader_force_remote_read();
let prefer_flight =
config.ballista_shuffle_reader_remote_prefer_flight();
+ let batch_size = config.batch_size();
if force_remote_read {
debug!(
@@ -200,11 +204,18 @@ impl ExecutionPlan for ShuffleReaderExec {
prefer_flight,
);
- let result = RecordBatchStreamAdapter::new(
- Arc::new(self.schema.as_ref().clone()),
+ let input_stream = Box::pin(RecordBatchStreamAdapter::new(
+ self.schema.clone(),
response_receiver.try_flatten(),
- );
- Ok(Box::pin(result))
+ ));
+
+ Ok(Box::pin(CoalescedShuffleReaderStream::new(
+ input_stream,
+ batch_size,
+ None,
+ &self.metrics,
+ partition,
+ )))
}
fn metrics(&self) -> Option<MetricsSet> {
@@ -594,6 +605,96 @@ async fn fetch_partition_object_store(
))
}
+struct CoalescedShuffleReaderStream {
+ schema: SchemaRef,
+ input: SendableRecordBatchStream,
+ coalescer: LimitedBatchCoalescer,
+ completed: bool,
+ baseline_metrics: BaselineMetrics,
+}
+
+impl CoalescedShuffleReaderStream {
+ pub fn new(
+ input: SendableRecordBatchStream,
+ batch_size: usize,
+ limit: Option<usize>,
+ metrics: &ExecutionPlanMetricsSet,
+ partition: usize,
+ ) -> Self {
+ let schema = input.schema();
+ Self {
+ schema: schema.clone(),
+ input,
+ coalescer: LimitedBatchCoalescer::new(schema, batch_size, limit),
+ completed: false,
+ baseline_metrics: BaselineMetrics::new(metrics, partition),
+ }
+ }
+}
+
+impl Stream for CoalescedShuffleReaderStream {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+ let _timer = elapsed_compute.timer();
+
+ loop {
+ // If there is already a completed batch ready, return it directly
+ if let Some(batch) = self.coalescer.next_completed_batch() {
+ self.baseline_metrics.record_output(batch.num_rows());
+ return Poll::Ready(Some(Ok(batch)));
+ }
+
+ // If the upstream is completed, then this stream is completed too
+ if self.completed {
+ return Poll::Ready(None);
+ }
+
+ // Pull from upstream
+ match ready!(self.input.poll_next_unpin(cx)) {
+ // If upstream is completed, then flush remaining buffered batches
+ None => {
+ self.completed = true;
+ if let Err(e) = self.coalescer.finish() {
+ return Poll::Ready(Some(Err(e)));
+ }
+ }
+ // If upstream is not completed, then push to coalescer
+ Some(Ok(batch)) => {
+ if batch.num_rows() > 0 {
+ // Try to push to coalescer
+ match self.coalescer.push_batch(batch) {
+ // If push is successful, then continue
+ Ok(PushBatchStatus::Continue) => {
+ continue;
+ }
+ // If limit is reached, then finish coalescer and set completed to true
+ Ok(PushBatchStatus::LimitReached) => {
+ self.completed = true;
+ if let Err(e) = self.coalescer.finish() {
+ return Poll::Ready(Some(Err(e)));
+ }
+ }
+ Err(e) => return Poll::Ready(Some(Err(e))),
+ }
+ }
+ }
+ Some(Err(e)) => return Poll::Ready(Some(Err(e))),
+ }
+ }
+ }
+}
+
+impl RecordBatchStream for CoalescedShuffleReaderStream {
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -1052,10 +1153,179 @@ mod tests {
.unwrap()
}
+ fn create_custom_test_batch(rows: usize) -> RecordBatch {
+ let schema = create_test_schema();
+
+ // 1. Create number column (0, 1, 2, ..., rows-1)
+ let number_vec: Vec<u32> = (0..rows as u32).collect();
+ let number_array = UInt32Array::from(number_vec);
+
+ // 2. Create string column ("s0", "s1", ..., "s{rows-1}")
+ // Just to fill data, the content is not important
+ let string_vec: Vec<String> = (0..rows).map(|i| format!("s{}",
i)).collect();
+ let string_array = StringArray::from(string_vec);
+
+ RecordBatch::try_new(schema, vec![Arc::new(number_array),
Arc::new(string_array)])
+ .unwrap()
+ }
+
fn create_test_schema() -> SchemaRef {
Arc::new(Schema::new(vec![
Field::new("number", DataType::UInt32, true),
Field::new("str", DataType::Utf8, true),
]))
}
+
+ use datafusion::physical_plan::memory::MemoryStream;
+
+ #[tokio::test]
+ async fn test_coalesce_stream_logic() -> Result<()> {
+ // 1. Create test data - 10 small batches, each with 3 rows
+ let schema = create_test_schema();
+ let small_batch = create_test_batch();
+ let batches = vec![small_batch.clone(); 10];
+
+ // 2. Create mock upstream stream (Input Stream)
+ let input_stream = MemoryStream::try_new(batches, schema.clone(),
None)?;
+ let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+ // 3. Configure Coalescer: target batch size to 10 rows
+ let target_batch_size = 10;
+
+ // 4. Manually build the CoalescedShuffleReaderStream
+ let coalesced_stream = CoalescedShuffleReaderStream::new(
+ input_stream,
+ target_batch_size,
+ None,
+ &ExecutionPlanMetricsSet::new(),
+ 0,
+ );
+
+ // 5. Execute stream and collect results
+ let output_batches =
common::collect(Box::pin(coalesced_stream)).await?;
+
+ // 6. Assertions
+ // Assert A: Data total not lost (30 rows)
+ let total_rows: usize = output_batches.iter().map(|b|
b.num_rows()).sum();
+ assert_eq!(total_rows, 30);
+
+ // Assert B: Batch count reduced (10 -> 3)
+ assert_eq!(output_batches.len(), 3);
+
+ // Assert C: Each batch size is correct (all should be 10)
+ assert_eq!(output_batches[0].num_rows(), 10);
+ assert_eq!(output_batches[1].num_rows(), 10);
+ assert_eq!(output_batches[2].num_rows(), 10);
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_coalesce_stream_remainder_flush() -> Result<()> {
+ let schema = create_test_schema();
+ // Create 10 small batches, each with 3 rows. Total 30 rows.
+ let small_batch = create_test_batch();
+ let batches = vec![small_batch.clone(); 10];
+
+ let input_stream = MemoryStream::try_new(batches, schema.clone(),
None)?;
+ let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+ // Target set to 100 rows.
+ // Because 30 < 100, the target is never reached; the stream must rely on `finish()` to flush these 30 rows at the end of the stream.
+ let target_batch_size = 100;
+
+ let coalesced_stream = CoalescedShuffleReaderStream::new(
+ input_stream,
+ target_batch_size,
+ None,
+ &ExecutionPlanMetricsSet::new(),
+ 0,
+ );
+
+ let output_batches =
common::collect(Box::pin(coalesced_stream)).await?;
+
+ // Assertions
+ assert_eq!(output_batches.len(), 1); // Should only have 1 batch
+ assert_eq!(output_batches[0].num_rows(), 30); // Should contain all 30 rows
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_coalesce_stream_large_batch() -> Result<()> {
+ let schema = create_test_schema();
+
+ // 1. Create a large batch (20 rows)
+ let big_batch = create_custom_test_batch(20);
+ let batches = vec![big_batch.clone(); 10]; // Total 200 rows
+
+ let input_stream = MemoryStream::try_new(batches, schema.clone(),
None)?;
+ let input_stream = Box::pin(input_stream) as SendableRecordBatchStream;
+
+ // 2. Target set to small size, 10 rows
+ let target_batch_size = 10;
+
+ let coalesced_stream = CoalescedShuffleReaderStream::new(
+ input_stream,
+ target_batch_size,
+ None,
+ &ExecutionPlanMetricsSet::new(),
+ 0,
+ );
+
+ let output_batches =
common::collect(Box::pin(coalesced_stream)).await?;
+
+ // 3. Validation: the large batch should not be split; it is output directly
+ // Coalescer will not split the batch if size > (max_batch_size / 2)
+ assert_eq!(output_batches.len(), 10);
+ assert_eq!(output_batches[0].num_rows(), 20);
+
+ Ok(())
+ }
+
+ use futures::stream;
+
+ #[tokio::test]
+ async fn test_coalesce_stream_error_propagation() -> Result<()> {
+ let schema = create_test_schema();
+ let small_batch = create_test_batch(); // 3 rows
+
+ // 1. Build the input items: one good batch followed by an error
+ let batches = vec![
+ Ok(small_batch),
+ Err(DataFusionError::Execution(
+ "Network connection failed".to_string(),
+ )),
+ ];
+
+ // 2. Turn the items into a stream
+ let stream = stream::iter(batches);
+ let input_stream =
+ Box::pin(RecordBatchStreamAdapter::new(schema.clone(), stream));
+
+ // 3. Configure Coalescer
+ let target_batch_size = 10;
+
+ let coalesced_stream = CoalescedShuffleReaderStream::new(
+ input_stream,
+ target_batch_size,
+ None,
+ &ExecutionPlanMetricsSet::new(),
+ 0,
+ );
+
+ // 4. Execute stream
+ let result = common::collect(Box::pin(coalesced_stream)).await;
+
+ // 5. Validation
+ assert!(result.is_err());
+ assert!(
+ result
+ .unwrap_err()
+ .to_string()
+ .contains("Network connection failed")
+ );
+
+ Ok(())
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]