This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 56a2af7741 Propagate .execute() calls immediately in `RepartitionExec` (#16093)
56a2af7741 is described below

commit 56a2af7741c43f3102ccbf7e93c31c8aa66869fd
Author: Gabriel <45515538+gabote...@users.noreply.github.com>
AuthorDate: Wed May 28 19:12:51 2025 +0200

    Propagate .execute() calls immediately in `RepartitionExec` (#16093)
    
    * Propagate .execute() calls immediately instead of lazily on the first RecordBatch poll
    
    * Address race condition: make consume_input_streams lazily initialize the RepartitionExecState if it was not initialized
    
    * Remove atomic bool for checking if the state was initialized
    
    ---------
    
    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
---
 datafusion/physical-plan/src/repartition/mod.rs | 204 +++++++++++++++---------
 1 file changed, 132 insertions(+), 72 deletions(-)

diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index ee5be01e29..d0ad506664 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -19,6 +19,7 @@
 //! partitions to M output partitions based on a partitioning scheme, 
optionally
 //! maintaining the order of the input rows in the output.
 
+use std::fmt::{Debug, Formatter};
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -45,7 +46,7 @@ use arrow::compute::take_arrays;
 use arrow::datatypes::{SchemaRef, UInt32Type};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::utils::transpose;
-use datafusion_common::HashMap;
+use datafusion_common::{internal_err, HashMap};
 use datafusion_common::{not_impl_err, DataFusionError, Result};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::memory_pool::MemoryConsumer;
@@ -67,9 +68,8 @@ type MaybeBatch = Option<Result<RecordBatch>>;
 type InputPartitionsToCurrentPartitionSender = 
Vec<DistributionSender<MaybeBatch>>;
 type InputPartitionsToCurrentPartitionReceiver = 
Vec<DistributionReceiver<MaybeBatch>>;
 
-/// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
-struct RepartitionExecState {
+struct ConsumingInputStreamsState {
     /// Channels for sending batches from input partitions to output 
partitions.
     /// Key is the partition number.
     channels: HashMap<
@@ -85,16 +85,97 @@ struct RepartitionExecState {
     abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }
 
+/// Inner state of [`RepartitionExec`].
+enum RepartitionExecState {
+    /// Not initialized yet. This is the default state stored in the 
RepartitionExec node
+    /// upon instantiation.
+    NotInitialized,
+    /// Input streams are initialized, but they are still not being consumed. 
The node
+    /// transitions to this state when the arrow's RecordBatch stream is 
created in
+    /// RepartitionExec::execute(), but before any message is polled.
+    InputStreamsInitialized(Vec<(SendableRecordBatchStream, 
RepartitionMetrics)>),
+    /// The input streams are being consumed. The node transitions to this 
state when
+    /// the first message in the arrow's RecordBatch stream is consumed.
+    ConsumingInputStreams(ConsumingInputStreamsState),
+}
+
+impl Default for RepartitionExecState {
+    fn default() -> Self {
+        Self::NotInitialized
+    }
+}
+
+impl Debug for RepartitionExecState {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            RepartitionExecState::NotInitialized => write!(f, 
"NotInitialized"),
+            RepartitionExecState::InputStreamsInitialized(v) => {
+                write!(f, "InputStreamsInitialized({:?})", v.len())
+            }
+            RepartitionExecState::ConsumingInputStreams(v) => {
+                write!(f, "ConsumingInputStreams({v:?})")
+            }
+        }
+    }
+}
+
 impl RepartitionExecState {
-    fn new(
+    fn ensure_input_streams_initialized(
+        &mut self,
+        input: Arc<dyn ExecutionPlan>,
+        metrics: ExecutionPlanMetricsSet,
+        output_partitions: usize,
+        ctx: Arc<TaskContext>,
+    ) -> Result<()> {
+        if !matches!(self, RepartitionExecState::NotInitialized) {
+            return Ok(());
+        }
+
+        let num_input_partitions = 
input.output_partitioning().partition_count();
+        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
+
+        for i in 0..num_input_partitions {
+            let metrics = RepartitionMetrics::new(i, output_partitions, 
&metrics);
+
+            let timer = metrics.fetch_time.timer();
+            let stream = input.execute(i, Arc::clone(&ctx))?;
+            timer.done();
+
+            streams_and_metrics.push((stream, metrics));
+        }
+        *self = 
RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
+        Ok(())
+    }
+
+    fn consume_input_streams(
+        &mut self,
         input: Arc<dyn ExecutionPlan>,
-        partitioning: Partitioning,
         metrics: ExecutionPlanMetricsSet,
+        partitioning: Partitioning,
         preserve_order: bool,
         name: String,
         context: Arc<TaskContext>,
-    ) -> Self {
-        let num_input_partitions = 
input.output_partitioning().partition_count();
+    ) -> Result<&mut ConsumingInputStreamsState> {
+        let streams_and_metrics = match self {
+            RepartitionExecState::NotInitialized => {
+                self.ensure_input_streams_initialized(
+                    input,
+                    metrics,
+                    partitioning.partition_count(),
+                    Arc::clone(&context),
+                )?;
+                let RepartitionExecState::InputStreamsInitialized(value) = 
self else {
+                    // This cannot happen, as 
ensure_input_streams_initialized() was just called,
+                    // but the compiler does not know.
+                    return internal_err!("Programming error: 
RepartitionExecState must be in the InputStreamsInitialized state after calling 
RepartitionExecState::ensure_input_streams_initialized");
+                };
+                value
+            }
+            RepartitionExecState::ConsumingInputStreams(value) => return 
Ok(value),
+            RepartitionExecState::InputStreamsInitialized(value) => value,
+        };
+
+        let num_input_partitions = streams_and_metrics.len();
         let num_output_partitions = partitioning.partition_count();
 
         let (txs, rxs) = if preserve_order {
@@ -129,7 +210,9 @@ impl RepartitionExecState {
 
         // launch one async task per *input* partition
         let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
-        for i in 0..num_input_partitions {
+        for (i, (stream, metrics)) in
+            std::mem::take(streams_and_metrics).into_iter().enumerate()
+        {
             let txs: HashMap<_, _> = channels
                 .iter()
                 .map(|(partition, (tx, _rx, reservation))| {
@@ -137,15 +220,11 @@ impl RepartitionExecState {
                 })
                 .collect();
 
-            let r_metrics = RepartitionMetrics::new(i, num_output_partitions, 
&metrics);
-
             let input_task = 
SpawnedTask::spawn(RepartitionExec::pull_from_input(
-                Arc::clone(&input),
-                i,
+                stream,
                 txs.clone(),
                 partitioning.clone(),
-                r_metrics,
-                Arc::clone(&context),
+                metrics,
             ));
 
             // In a separate task, wait for each input to be done
@@ -158,28 +237,17 @@ impl RepartitionExecState {
             ));
             spawned_tasks.push(wait_for_task);
         }
-
-        Self {
+        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
             channels,
             abort_helper: Arc::new(spawned_tasks),
+        });
+        match self {
+            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
+            _ => unreachable!(),
         }
     }
 }
 
-/// Lazily initialized state
-///
-/// Note that the state is initialized ONCE for all partitions by a single 
task(thread).
-/// This may take a short while.  It is also like that multiple threads
-/// call execute at the same time, because we have just started "target 
partitions" tasks
-/// which is commonly set to the number of CPU cores and all call execute at 
the same time.
-///
-/// Thus, use a **tokio** `OnceCell` for this initialization so as not to 
waste CPU cycles
-/// in a mutex lock but instead allow other threads to do something useful.
-///
-/// Uses a parking_lot `Mutex` to control other accesses as they are very 
short duration
-///  (e.g. removing channels on completion) where the overhead of `await` is 
not warranted.
-type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;
-
 /// A utility that can be used to partition batches based on [`Partitioning`]
 pub struct BatchPartitioner {
     state: BatchPartitionerState,
@@ -406,8 +474,9 @@ impl BatchPartitioner {
 pub struct RepartitionExec {
     /// Input execution plan
     input: Arc<dyn ExecutionPlan>,
-    /// Inner state that is initialized when the first output stream is 
created.
-    state: LazyState,
+    /// Inner state that is initialized when the parent calls .execute() on 
this node
+    /// and consumed as soon as the parent starts consuming this node.
+    state: Arc<Mutex<RepartitionExecState>>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
     /// Boolean flag to decide whether to preserve ordering. If true means
@@ -486,11 +555,7 @@ impl RepartitionExec {
 }
 
 impl DisplayAs for RepartitionExec {
-    fn fmt_as(
-        &self,
-        t: DisplayFormatType,
-        f: &mut std::fmt::Formatter,
-    ) -> std::fmt::Result {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> 
std::fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 write!(
@@ -583,7 +648,6 @@ impl ExecutionPlan for RepartitionExec {
             partition
         );
 
-        let lazy_state = Arc::clone(&self.state);
         let input = Arc::clone(&self.input);
         let partitioning = self.partitioning().clone();
         let metrics = self.metrics.clone();
@@ -595,30 +659,31 @@ impl ExecutionPlan for RepartitionExec {
         // Get existing ordering to use for merging
         let sort_exprs = self.sort_exprs().cloned().unwrap_or_default();
 
+        let state = Arc::clone(&self.state);
+        if let Some(mut state) = state.try_lock() {
+            state.ensure_input_streams_initialized(
+                Arc::clone(&input),
+                metrics.clone(),
+                partitioning.partition_count(),
+                Arc::clone(&context),
+            )?;
+        }
+
         let stream = futures::stream::once(async move {
             let num_input_partitions = 
input.output_partitioning().partition_count();
 
-            let input_captured = Arc::clone(&input);
-            let metrics_captured = metrics.clone();
-            let name_captured = name.clone();
-            let context_captured = Arc::clone(&context);
-            let state = lazy_state
-                .get_or_init(|| async move {
-                    Mutex::new(RepartitionExecState::new(
-                        input_captured,
-                        partitioning,
-                        metrics_captured,
-                        preserve_order,
-                        name_captured,
-                        context_captured,
-                    ))
-                })
-                .await;
-
             // lock scope
             let (mut rx, reservation, abort_helper) = {
                 // lock mutexes
                 let mut state = state.lock();
+                let state = state.consume_input_streams(
+                    Arc::clone(&input),
+                    metrics.clone(),
+                    partitioning,
+                    preserve_order,
+                    name.clone(),
+                    Arc::clone(&context),
+                )?;
 
                 // now return stream for the specified *output* partition 
which will
                 // read from the channel
@@ -853,24 +918,17 @@ impl RepartitionExec {
     ///
     /// txs hold the output sending channels for each output partition
     async fn pull_from_input(
-        input: Arc<dyn ExecutionPlan>,
-        partition: usize,
+        mut stream: SendableRecordBatchStream,
         mut output_channels: HashMap<
             usize,
             (DistributionSender<MaybeBatch>, SharedMemoryReservation),
         >,
         partitioning: Partitioning,
         metrics: RepartitionMetrics,
-        context: Arc<TaskContext>,
     ) -> Result<()> {
         let mut partitioner =
             BatchPartitioner::try_new(partitioning, 
metrics.repartition_time.clone())?;
 
-        // execute the child operator
-        let timer = metrics.fetch_time.timer();
-        let mut stream = input.execute(partition, context)?;
-        timer.done();
-
         // While there are still outputs to send to, keep pulling inputs
         let mut batches_until_yield = partitioner.num_partitions();
         while !output_channels.is_empty() {
@@ -1118,6 +1176,7 @@ mod tests {
     use datafusion_common_runtime::JoinSet;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use insta::assert_snapshot;
+    use itertools::Itertools;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -1298,15 +1357,9 @@ mod tests {
         let partitioning = Partitioning::RoundRobinBatch(1);
         let exec = RepartitionExec::try_new(Arc::new(input), 
partitioning).unwrap();
 
-        // Note: this should pass (the stream can be created) but the
-        // error when the input is executed should get passed back
-        let output_stream = exec.execute(0, task_ctx).unwrap();
-
         // Expect that an error is returned
-        let result_string = crate::common::collect(output_stream)
-            .await
-            .unwrap_err()
-            .to_string();
+        let result_string = exec.execute(0, 
task_ctx).err().unwrap().to_string();
+
         assert!(
             result_string.contains("ErrorExec, unsurprisingly, errored in 
partition 0"),
             "actual: {result_string}"
@@ -1496,7 +1549,14 @@ mod tests {
         });
         let batches_with_drop = 
crate::common::collect(output_stream1).await.unwrap();
 
-        assert_eq!(batches_without_drop, batches_with_drop);
+        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
+            batch
+                .into_iter()
+                .sorted_by_key(|b| format!("{b:?}"))
+                .collect()
+        }
+
+        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
     }
 
     fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to