This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
     new 56a2af7741 Propagate .execute() calls immediately in `RepartitionExec` (#16093)
56a2af7741 is described below

commit 56a2af7741c43f3102ccbf7e93c31c8aa66869fd
Author:     Gabriel <45515538+gabote...@users.noreply.github.com>
AuthorDate: Wed May 28 19:12:51 2025 +0200

    Propagate .execute() calls immediately in `RepartitionExec` (#16093)

    * Propagate .execute() calls immediately instead of lazily on the first RecordBatch poll

    * Address race condition: make consume_input_streams lazily initialize the RepartitionExecState if it was not initialized

    * Remove atomic bool for checking if the state was initialized

    ---------

    Co-authored-by: Andrew Lamb <and...@nerdnetworks.org>
---
 datafusion/physical-plan/src/repartition/mod.rs | 204 +++++++++++++++---------
 1 file changed, 132 insertions(+), 72 deletions(-)

diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index ee5be01e29..d0ad506664 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -19,6 +19,7 @@
 //! partitions to M output partitions based on a partitioning scheme, optionally
 //! maintaining the order of the input rows in the output.
 
+use std::fmt::{Debug, Formatter};
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -45,7 +46,7 @@ use arrow::compute::take_arrays;
 use arrow::datatypes::{SchemaRef, UInt32Type};
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::utils::transpose;
-use datafusion_common::HashMap;
+use datafusion_common::{internal_err, HashMap};
 use datafusion_common::{not_impl_err, DataFusionError, Result};
 use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::memory_pool::MemoryConsumer;
@@ -67,9 +68,8 @@ type MaybeBatch = Option<Result<RecordBatch>>;
 type InputPartitionsToCurrentPartitionSender = Vec<DistributionSender<MaybeBatch>>;
 type InputPartitionsToCurrentPartitionReceiver = Vec<DistributionReceiver<MaybeBatch>>;
 
-/// Inner state of [`RepartitionExec`].
 #[derive(Debug)]
-struct RepartitionExecState {
+struct ConsumingInputStreamsState {
     /// Channels for sending batches from input partitions to output partitions.
     /// Key is the partition number.
     channels: HashMap<
@@ -85,16 +85,97 @@ struct RepartitionExecState {
     abort_helper: Arc<Vec<SpawnedTask<()>>>,
 }
 
+/// Inner state of [`RepartitionExec`].
+enum RepartitionExecState {
+    /// Not initialized yet. This is the default state stored in the RepartitionExec node
+    /// upon instantiation.
+    NotInitialized,
+    /// Input streams are initialized, but they are still not being consumed. The node
+    /// transitions to this state when the arrow's RecordBatch stream is created in
+    /// RepartitionExec::execute(), but before any message is polled.
+    InputStreamsInitialized(Vec<(SendableRecordBatchStream, RepartitionMetrics)>),
+    /// The input streams are being consumed. The node transitions to this state when
+    /// the first message in the arrow's RecordBatch stream is consumed.
+    ConsumingInputStreams(ConsumingInputStreamsState),
+}
+
+impl Default for RepartitionExecState {
+    fn default() -> Self {
+        Self::NotInitialized
+    }
+}
+
+impl Debug for RepartitionExecState {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            RepartitionExecState::NotInitialized => write!(f, "NotInitialized"),
+            RepartitionExecState::InputStreamsInitialized(v) => {
+                write!(f, "InputStreamsInitialized({:?})", v.len())
+            }
+            RepartitionExecState::ConsumingInputStreams(v) => {
+                write!(f, "ConsumingInputStreams({v:?})")
+            }
+        }
+    }
+}
+
 impl RepartitionExecState {
-    fn new(
+    fn ensure_input_streams_initialized(
+        &mut self,
+        input: Arc<dyn ExecutionPlan>,
+        metrics: ExecutionPlanMetricsSet,
+        output_partitions: usize,
+        ctx: Arc<TaskContext>,
+    ) -> Result<()> {
+        if !matches!(self, RepartitionExecState::NotInitialized) {
+            return Ok(());
+        }
+
+        let num_input_partitions = input.output_partitioning().partition_count();
+        let mut streams_and_metrics = Vec::with_capacity(num_input_partitions);
+
+        for i in 0..num_input_partitions {
+            let metrics = RepartitionMetrics::new(i, output_partitions, &metrics);
+
+            let timer = metrics.fetch_time.timer();
+            let stream = input.execute(i, Arc::clone(&ctx))?;
+            timer.done();
+
+            streams_and_metrics.push((stream, metrics));
+        }
+        *self = RepartitionExecState::InputStreamsInitialized(streams_and_metrics);
+        Ok(())
+    }
+
+    fn consume_input_streams(
+        &mut self,
         input: Arc<dyn ExecutionPlan>,
-        partitioning: Partitioning,
         metrics: ExecutionPlanMetricsSet,
+        partitioning: Partitioning,
         preserve_order: bool,
         name: String,
         context: Arc<TaskContext>,
-    ) -> Self {
-        let num_input_partitions = input.output_partitioning().partition_count();
+    ) -> Result<&mut ConsumingInputStreamsState> {
+        let streams_and_metrics = match self {
+            RepartitionExecState::NotInitialized => {
+                self.ensure_input_streams_initialized(
+                    input,
+                    metrics,
+                    partitioning.partition_count(),
+                    Arc::clone(&context),
+                )?;
+                let RepartitionExecState::InputStreamsInitialized(value) = self else {
+                    // This cannot happen, as ensure_input_streams_initialized() was just called,
+                    // but the compiler does not know.
+                    return internal_err!("Programming error: RepartitionExecState must be in the InputStreamsInitialized state after calling RepartitionExecState::ensure_input_streams_initialized");
+                };
+                value
+            }
+            RepartitionExecState::ConsumingInputStreams(value) => return Ok(value),
+            RepartitionExecState::InputStreamsInitialized(value) => value,
+        };
+
+        let num_input_partitions = streams_and_metrics.len();
         let num_output_partitions = partitioning.partition_count();
 
         let (txs, rxs) = if preserve_order {
@@ -129,7 +210,9 @@ impl RepartitionExecState {
         // launch one async task per *input* partition
         let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
 
-        for i in 0..num_input_partitions {
+        for (i, (stream, metrics)) in
+            std::mem::take(streams_and_metrics).into_iter().enumerate()
+        {
             let txs: HashMap<_, _> = channels
                 .iter()
                 .map(|(partition, (tx, _rx, reservation))| {
@@ -137,15 +220,11 @@ impl RepartitionExecState {
                 })
                 .collect();
 
-            let r_metrics = RepartitionMetrics::new(i, num_output_partitions, &metrics);
-
             let input_task = SpawnedTask::spawn(RepartitionExec::pull_from_input(
-                Arc::clone(&input),
-                i,
+                stream,
                 txs.clone(),
                 partitioning.clone(),
-                r_metrics,
-                Arc::clone(&context),
+                metrics,
             ));
 
             // In a separate task, wait for each input to be done
@@ -158,28 +237,17 @@ impl RepartitionExecState {
             ));
             spawned_tasks.push(wait_for_task);
         }
-
-        Self {
+        *self = Self::ConsumingInputStreams(ConsumingInputStreamsState {
             channels,
             abort_helper: Arc::new(spawned_tasks),
+        });
+        match self {
+            RepartitionExecState::ConsumingInputStreams(value) => Ok(value),
+            _ => unreachable!(),
         }
     }
 }
 
-/// Lazily initialized state
-///
-/// Note that the state is initialized ONCE for all partitions by a single task(thread).
-/// This may take a short while. It is also like that multiple threads
-/// call execute at the same time, because we have just started "target partitions" tasks
-/// which is commonly set to the number of CPU cores and all call execute at the same time.
-///
-/// Thus, use a **tokio** `OnceCell` for this initialization so as not to waste CPU cycles
-/// in a mutex lock but instead allow other threads to do something useful.
-///
-/// Uses a parking_lot `Mutex` to control other accesses as they are very short duration
-/// (e.g. removing channels on completion) where the overhead of `await` is not warranted.
-type LazyState = Arc<tokio::sync::OnceCell<Mutex<RepartitionExecState>>>;
-
 /// A utility that can be used to partition batches based on [`Partitioning`]
 pub struct BatchPartitioner {
     state: BatchPartitionerState,
@@ -406,8 +474,9 @@ impl BatchPartitioner {
 pub struct RepartitionExec {
     /// Input execution plan
     input: Arc<dyn ExecutionPlan>,
-    /// Inner state that is initialized when the first output stream is created.
-    state: LazyState,
+    /// Inner state that is initialized when the parent calls .execute() on this node
+    /// and consumed as soon as the parent starts consuming this node.
+    state: Arc<Mutex<RepartitionExecState>>,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
     /// Boolean flag to decide whether to preserve ordering. If true means
@@ -486,11 +555,7 @@ impl RepartitionExec {
 }
 
 impl DisplayAs for RepartitionExec {
-    fn fmt_as(
-        &self,
-        t: DisplayFormatType,
-        f: &mut std::fmt::Formatter,
-    ) -> std::fmt::Result {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
         match t {
             DisplayFormatType::Default | DisplayFormatType::Verbose => {
                 write!(
@@ -583,7 +648,6 @@ impl ExecutionPlan for RepartitionExec {
             partition
         );
 
-        let lazy_state = Arc::clone(&self.state);
         let input = Arc::clone(&self.input);
         let partitioning = self.partitioning().clone();
         let metrics = self.metrics.clone();
@@ -595,30 +659,31 @@ impl ExecutionPlan for RepartitionExec {
         // Get existing ordering to use for merging
         let sort_exprs = self.sort_exprs().cloned().unwrap_or_default();
 
+        let state = Arc::clone(&self.state);
+        if let Some(mut state) = state.try_lock() {
+            state.ensure_input_streams_initialized(
+                Arc::clone(&input),
+                metrics.clone(),
+                partitioning.partition_count(),
+                Arc::clone(&context),
+            )?;
+        }
+
         let stream = futures::stream::once(async move {
             let num_input_partitions = input.output_partitioning().partition_count();
 
-            let input_captured = Arc::clone(&input);
-            let metrics_captured = metrics.clone();
-            let name_captured = name.clone();
-            let context_captured = Arc::clone(&context);
-            let state = lazy_state
-                .get_or_init(|| async move {
-                    Mutex::new(RepartitionExecState::new(
-                        input_captured,
-                        partitioning,
-                        metrics_captured,
-                        preserve_order,
-                        name_captured,
-                        context_captured,
-                    ))
-                })
-                .await;
-
             // lock scope
             let (mut rx, reservation, abort_helper) = {
                 // lock mutexes
                 let mut state = state.lock();
+                let state = state.consume_input_streams(
+                    Arc::clone(&input),
+                    metrics.clone(),
+                    partitioning,
+                    preserve_order,
+                    name.clone(),
+                    Arc::clone(&context),
+                )?;
 
                 // now return stream for the specified *output* partition which will
                 // read from the channel
@@ -853,24 +918,17 @@ impl RepartitionExec {
     ///
     /// txs hold the output sending channels for each output partition
     async fn pull_from_input(
-        input: Arc<dyn ExecutionPlan>,
-        partition: usize,
+        mut stream: SendableRecordBatchStream,
         mut output_channels: HashMap<
            usize,
            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
        >,
         partitioning: Partitioning,
         metrics: RepartitionMetrics,
-        context: Arc<TaskContext>,
     ) -> Result<()> {
         let mut partitioner =
             BatchPartitioner::try_new(partitioning, metrics.repartition_time.clone())?;
 
-        // execute the child operator
-        let timer = metrics.fetch_time.timer();
-        let mut stream = input.execute(partition, context)?;
-        timer.done();
-
         // While there are still outputs to send to, keep pulling inputs
         let mut batches_until_yield = partitioner.num_partitions();
         while !output_channels.is_empty() {
@@ -1118,6 +1176,7 @@ mod tests {
     use datafusion_common_runtime::JoinSet;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use insta::assert_snapshot;
+    use itertools::Itertools;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -1298,15 +1357,9 @@ mod tests {
         let partitioning = Partitioning::RoundRobinBatch(1);
         let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();
-        // Note: this should pass (the stream can be created) but the
-        // error when the input is executed should get passed back
-        let output_stream = exec.execute(0, task_ctx).unwrap();
-
         // Expect that an error is returned
-        let result_string = crate::common::collect(output_stream)
-            .await
-            .unwrap_err()
-            .to_string();
+        let result_string = exec.execute(0, task_ctx).err().unwrap().to_string();
+
         assert!(
             result_string.contains("ErrorExec, unsurprisingly, errored in partition 0"),
             "actual: {result_string}"
         );
@@ -1496,7 +1549,14 @@ mod tests {
         });
         let batches_with_drop = crate::common::collect(output_stream1).await.unwrap();
 
-        assert_eq!(batches_without_drop, batches_with_drop);
+        fn sort(batch: Vec<RecordBatch>) -> Vec<RecordBatch> {
+            batch
+                .into_iter()
+                .sorted_by_key(|b| format!("{b:?}"))
+                .collect()
+        }
+
+        assert_eq!(sort(batches_without_drop), sort(batches_with_drop));
     }
 
     fn str_batches_to_vec(batches: &[RecordBatch]) -> Vec<&str> {
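The change is easiest to read as a three-state lifecycle replacing the old `LazyState` (`tokio::sync::OnceCell`): `execute()` now eagerly runs `ensure_input_streams_initialized`, so an error from a child's `execute()` surfaces immediately (as the updated test above asserts), while channel setup and task spawning are deferred to `consume_input_streams` on the first poll. The following is a minimal, self-contained sketch of that pattern under simplified assumptions; `State` and `Stream` are illustrative stand-ins, not the DataFusion types:

// Sketch of the state machine introduced by this commit (assumed, simplified
// names; not the DataFusion API).
use std::sync::{Arc, Mutex};

// Stand-in for an initialized input stream (SendableRecordBatchStream).
struct Stream;

enum State {
    NotInitialized,
    InputStreamsInitialized(Vec<Stream>),
    ConsumingInputStreams { num_streams: usize },
}

impl State {
    // Eager half, called from execute(): idempotent, and the point where an
    // error from a child's execute() would be returned immediately.
    fn ensure_input_streams_initialized(&mut self, n: usize) -> Result<(), String> {
        if !matches!(self, State::NotInitialized) {
            return Ok(());
        }
        let streams = (0..n).map(|_| Stream).collect();
        *self = State::InputStreamsInitialized(streams);
        Ok(())
    }

    // Lazy half, called on the first poll: re-checks initialization (covering
    // the case where execute() never got the lock), then consumes the streams
    // exactly once.
    fn consume_input_streams(&mut self, n: usize) -> Result<usize, String> {
        if let State::ConsumingInputStreams { num_streams } = self {
            return Ok(*num_streams);
        }
        self.ensure_input_streams_initialized(n)?;
        let State::InputStreamsInitialized(streams) = self else {
            unreachable!("just initialized above");
        };
        let num_streams = streams.len();
        // The real operator spawns one task per input stream here and wires
        // up the per-output-partition channels before switching state.
        *self = State::ConsumingInputStreams { num_streams };
        Ok(num_streams)
    }
}

fn main() -> Result<(), String> {
    let state = Arc::new(Mutex::new(State::NotInitialized));
    state.lock().unwrap().ensure_input_streams_initialized(4)?; // execute()
    let n = state.lock().unwrap().consume_input_streams(4)?; // first poll
    assert_eq!(n, 4);
    Ok(())
}

Design note: in the actual patch, `execute()` uses `try_lock` rather than `lock`, so concurrent `execute()` calls never block one another; if the lock is contended, initialization simply falls through to the first poll, which is why `consume_input_streams` lazily re-initializes — the race-condition fix named in the second bullet of the commit message.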