This is an automated email from the ASF dual-hosted git repository.

yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a79f332a0 Simplify sort streams (#2296)
a79f332a0 is described below

commit a79f332a0b87a307ac350fedd5c3d55dad7d5940
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Apr 21 02:42:52 2022 +0100

    Simplify sort streams (#2296)
---
 .../core/src/physical_plan/coalesce_partitions.rs  |  24 ++---
 datafusion/core/src/physical_plan/common.rs        |   6 +-
 datafusion/core/src/physical_plan/sorts/mod.rs     |  47 ---------
 .../physical_plan/sorts/sort_preserving_merge.rs   | 110 ++++++++-------------
 4 files changed, 53 insertions(+), 134 deletions(-)

diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index 3ecbd61f2..35f75db03 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -22,8 +22,8 @@ use std::any::Any;
 use std::sync::Arc;
 use std::task::Poll;
 
-use futures::channel::mpsc;
 use futures::Stream;
+use tokio::sync::mpsc;
 
 use async_trait::async_trait;
 
@@ -40,7 +40,6 @@ use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
 use super::SendableRecordBatchStream;
 use crate::execution::context::TaskContext;
 use crate::physical_plan::common::spawn_execution;
-use pin_project_lite::pin_project;
 
 /// Merge execution plan executes partitions in parallel and combines them into a single
 /// partition. No guarantees are made about the order of the resulting partition.
@@ -180,26 +179,23 @@ impl ExecutionPlan for CoalescePartitionsExec {
     }
 }
 
-pin_project! {
-    struct MergeStream {
-        schema: SchemaRef,
-        #[pin]
-        input: mpsc::Receiver<ArrowResult<RecordBatch>>,
-        baseline_metrics: BaselineMetrics,
-        drop_helper: AbortOnDropMany<()>,
-    }
+struct MergeStream {
+    schema: SchemaRef,
+    input: mpsc::Receiver<ArrowResult<RecordBatch>>,
+    baseline_metrics: BaselineMetrics,
+    #[allow(unused)]
+    drop_helper: AbortOnDropMany<()>,
 }
 
 impl Stream for MergeStream {
     type Item = ArrowResult<RecordBatch>;
 
     fn poll_next(
-        self: std::pin::Pin<&mut Self>,
+        mut self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        let this = self.project();
-        let poll = this.input.poll_next(cx);
-        this.baseline_metrics.record_poll(poll)
+        let poll = self.input.poll_recv(cx);
+        self.baseline_metrics.record_poll(poll)
     }
 }
 
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index b7313b9f2..68bd676dd 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -28,14 +28,14 @@ use arrow::error::ArrowError;
 use arrow::error::Result as ArrowResult;
 use arrow::ipc::writer::FileWriter;
 use arrow::record_batch::RecordBatch;
-use futures::channel::mpsc;
-use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt};
+use futures::{Future, Stream, StreamExt, TryStreamExt};
 use pin_project_lite::pin_project;
 use std::fs;
 use std::fs::{metadata, File};
 use std::path::{Path, PathBuf};
 use std::sync::Arc;
 use std::task::{Context, Poll};
+use tokio::sync::mpsc;
 use tokio::task::JoinHandle;
 
 /// Stream of record batches
@@ -174,7 +174,7 @@ fn build_file_list_recurse(
 /// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
 pub(crate) fn spawn_execution(
     input: Arc<dyn ExecutionPlan>,
-    mut output: mpsc::Sender<ArrowResult<RecordBatch>>,
+    output: mpsc::Sender<ArrowResult<RecordBatch>>,
     partition: usize,
     context: Arc<TaskContext>,
 ) -> JoinHandle<()> {
diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/core/src/physical_plan/sorts/mod.rs
index 818546f31..8d499be3a 100644
--- a/datafusion/core/src/physical_plan/sorts/mod.rs
+++ b/datafusion/core/src/physical_plan/sorts/mod.rs
@@ -22,19 +22,13 @@ use crate::error::{DataFusionError, Result};
 use crate::physical_plan::{PhysicalExpr, SendableRecordBatchStream};
 use arrow::array::{ArrayRef, DynComparator};
 use arrow::compute::SortOptions;
-use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
-use futures::channel::mpsc;
-use futures::stream::FusedStream;
-use futures::Stream;
 use hashbrown::HashMap;
 use parking_lot::RwLock;
 use std::borrow::BorrowMut;
 use std::cmp::Ordering;
 use std::fmt::{Debug, Formatter};
-use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
 pub mod sort;
 pub mod sort_preserving_merge;
@@ -242,44 +236,3 @@ impl SortedStream {
         Self { stream, mem_used }
     }
 }
-
-#[derive(Debug)]
-enum StreamWrapper {
-    Receiver(mpsc::Receiver<ArrowResult<RecordBatch>>),
-    Stream(Option<SortedStream>),
-}
-
-impl Stream for StreamWrapper {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.get_mut() {
-            StreamWrapper::Receiver(ref mut receiver) => Pin::new(receiver).poll_next(cx),
-            StreamWrapper::Stream(ref mut stream) => {
-                let inner = match stream {
-                    None => return Poll::Ready(None),
-                    Some(inner) => inner,
-                };
-
-                match Pin::new(&mut inner.stream).poll_next(cx) {
-                    Poll::Ready(msg) => {
-                        if msg.is_none() {
-                            *stream = None
-                        }
-                        Poll::Ready(msg)
-                    }
-                    Poll::Pending => Poll::Pending,
-                }
-            }
-        }
-    }
-}
-
-impl FusedStream for StreamWrapper {
-    fn is_terminated(&self) -> bool {
-        match self {
-            StreamWrapper::Receiver(receiver) => receiver.is_terminated(),
-            StreamWrapper::Stream(stream) => stream.is_none(),
-        }
-    }
-}
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4bc3606e0..6b8367a53 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -17,7 +17,6 @@
 
 //! Defines the sort preserving merge plan
 
-use crate::physical_plan::common::AbortOnDropMany;
 use crate::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
 };
@@ -25,7 +24,6 @@ use log::debug;
 use parking_lot::Mutex;
 use std::any::Any;
 use std::collections::{BinaryHeap, VecDeque};
-use std::fmt::Debug;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
@@ -38,13 +36,14 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use async_trait::async_trait;
-use futures::channel::mpsc;
-use futures::stream::FusedStream;
+use futures::stream::{Fuse, FusedStream};
 use futures::{Stream, StreamExt};
+use tokio::sync::mpsc;
 
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::TaskContext;
-use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper};
+use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream};
+use crate::physical_plan::stream::RecordBatchReceiverStream;
 use crate::physical_plan::{
     common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
     Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
@@ -174,6 +173,8 @@ impl ExecutionPlan for SortPreservingMergeExec {
             "Number of input partitions of  SortPreservingMergeExec::execute: {}",
             input_partitions
         );
+        let schema = self.schema();
+
         match input_partitions {
             0 => Err(DataFusionError::Internal(
                 "SortPreservingMergeExec requires at least one input partition"
@@ -186,7 +187,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
                 result
             }
             _ => {
-                let (receivers, join_handles) = (0..input_partitions)
+                let receivers = (0..input_partitions)
                     .into_iter()
                     .map(|part_i| {
                         let (sender, receiver) = mpsc::channel(1);
@@ -196,16 +197,23 @@ impl ExecutionPlan for SortPreservingMergeExec {
                             part_i,
                             context.clone(),
                         );
-                        (receiver, join_handle)
+
+                        SortedStream::new(
+                            RecordBatchReceiverStream::create(
+                                &schema,
+                                receiver,
+                                join_handle,
+                            ),
+                            0,
+                        )
                     })
-                    .unzip();
+                    .collect();
 
                 debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");
 
-                let result = Box::pin(SortPreservingMergeStream::new_from_receivers(
+                let result = Box::pin(SortPreservingMergeStream::new_from_streams(
                     receivers,
-                    AbortOnDropMany(join_handles),
-                    self.schema(),
+                    schema,
                     &self.expr,
                     tracking_metrics,
                     context.session_config().batch_size,
@@ -240,16 +248,23 @@ impl ExecutionPlan for SortPreservingMergeExec {
     }
 }
 
-#[derive(Debug)]
 struct MergingStreams {
     /// The sorted input streams to merge together
-    streams: Mutex<Vec<StreamWrapper>>,
+    streams: Mutex<Vec<Fuse<SendableRecordBatchStream>>>,
     /// number of streams
     num_streams: usize,
 }
 
+impl std::fmt::Debug for MergingStreams {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("MergingStreams")
+            .field("num_streams", &self.num_streams)
+            .finish()
+    }
+}
+
 impl MergingStreams {
-    fn new(input_streams: Vec<StreamWrapper>) -> Self {
+    fn new(input_streams: Vec<Fuse<SendableRecordBatchStream>>) -> Self {
         Self {
             num_streams: input_streams.len(),
             streams: Mutex::new(input_streams),
@@ -269,9 +284,6 @@ pub(crate) struct SortPreservingMergeStream {
     /// The sorted input streams to merge together
     streams: MergingStreams,
 
-    /// Drop helper for tasks feeding the input [`streams`](Self::streams)
-    _drop_helper: AbortOnDropMany<()>,
-
     /// For each input stream maintain a dequeue of RecordBatches
     ///
     /// Exhausted batches will be popped off the front once all
@@ -308,39 +320,6 @@ pub(crate) struct SortPreservingMergeStream {
 }
 
 impl SortPreservingMergeStream {
-    pub(crate) fn new_from_receivers(
-        receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
-        _drop_helper: AbortOnDropMany<()>,
-        schema: SchemaRef,
-        expressions: &[PhysicalSortExpr],
-        tracking_metrics: MemTrackingMetrics,
-        batch_size: usize,
-    ) -> Self {
-        debug!("Start SortPreservingMergeStream::new_from_receivers");
-        let stream_count = receivers.len();
-        let batches = (0..stream_count)
-            .into_iter()
-            .map(|_| VecDeque::new())
-            .collect();
-        let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect();
-
-        SortPreservingMergeStream {
-            schema,
-            batches,
-            cursor_finished: vec![true; stream_count],
-            streams: MergingStreams::new(wrappers),
-            _drop_helper,
-            column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
-            sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
-            tracking_metrics,
-            aborted: false,
-            in_progress: vec![],
-            next_batch_id: 0,
-            min_heap: BinaryHeap::with_capacity(stream_count),
-            batch_size,
-        }
-    }
-
     pub(crate) fn new_from_streams(
         streams: Vec<SortedStream>,
         schema: SchemaRef,
@@ -354,17 +333,13 @@ impl SortPreservingMergeStream {
             .map(|_| VecDeque::new())
             .collect();
         tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum());
-        let wrappers = streams
-            .into_iter()
-            .map(|s| StreamWrapper::Stream(Some(s)))
-            .collect();
+        let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect();
 
         Self {
             schema,
             batches,
             cursor_finished: vec![true; stream_count],
             streams: MergingStreams::new(wrappers),
-            _drop_helper: AbortOnDropMany(vec![]),
             column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
             sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
             tracking_metrics,
@@ -638,7 +613,7 @@ mod tests {
     use super::*;
     use crate::prelude::{SessionConfig, SessionContext};
     use arrow::datatypes::{DataType, Field, Schema};
-    use futures::{FutureExt, SinkExt};
+    use futures::FutureExt;
     use tokio_stream::StreamExt;
 
     #[tokio::test]
@@ -1213,11 +1188,10 @@ mod tests {
             sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await;
 
         let partition_count = batches.output_partitioning().partition_count();
-        let mut join_handles = Vec::with_capacity(partition_count);
-        let mut receivers = Vec::with_capacity(partition_count);
+        let mut streams = Vec::with_capacity(partition_count);
 
         for partition in 0..partition_count {
-            let (mut sender, receiver) = mpsc::channel(1);
+            let (sender, receiver) = mpsc::channel(1);
             let mut stream = batches.execute(partition, task_ctx.clone()).await.unwrap();
             let join_handle = tokio::spawn(async move {
                 while let Some(batch) = stream.next().await {
@@ -1226,17 +1200,18 @@ mod tests {
                     
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                 }
             });
-            join_handles.push(join_handle);
-            receivers.push(receiver);
+
+            streams.push(SortedStream::new(
+                RecordBatchReceiverStream::create(&schema, receiver, join_handle),
+                0,
+            ));
         }
 
         let metrics = ExecutionPlanMetricsSet::new();
         let tracking_metrics = MemTrackingMetrics::new(&metrics, 0);
 
-        let merge_stream = SortPreservingMergeStream::new_from_receivers(
-            receivers,
-            // Use empty vector since we want to use the join handles ourselves
-            AbortOnDropMany(vec![]),
+        let merge_stream = SortPreservingMergeStream::new_from_streams(
+            streams,
             batches.schema(),
             sort.as_slice(),
             tracking_metrics,
@@ -1245,11 +1220,6 @@ mod tests {
 
         let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
 
-        // Propagate any errors
-        for join_handle in join_handles {
-            join_handle.await.unwrap();
-        }
-
         assert_eq!(merged.len(), 1);
         let merged = merged.remove(0);
         let basic = basic_sort(batches, sort.clone(), task_ctx.clone()).await;

Reply via email to