This is an automated email from the ASF dual-hosted git repository.
yjshen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a79f332a0 Simplify sort streams (#2296)
a79f332a0 is described below
commit a79f332a0b87a307ac350fedd5c3d55dad7d5940
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Thu Apr 21 02:42:52 2022 +0100
Simplify sort streams (#2296)
---
.../core/src/physical_plan/coalesce_partitions.rs | 24 ++---
datafusion/core/src/physical_plan/common.rs | 6 +-
datafusion/core/src/physical_plan/sorts/mod.rs | 47 ---------
.../physical_plan/sorts/sort_preserving_merge.rs | 110 ++++++++-------------
4 files changed, 53 insertions(+), 134 deletions(-)
diff --git a/datafusion/core/src/physical_plan/coalesce_partitions.rs b/datafusion/core/src/physical_plan/coalesce_partitions.rs
index 3ecbd61f2..35f75db03 100644
--- a/datafusion/core/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/core/src/physical_plan/coalesce_partitions.rs
@@ -22,8 +22,8 @@ use std::any::Any;
use std::sync::Arc;
use std::task::Poll;
-use futures::channel::mpsc;
use futures::Stream;
+use tokio::sync::mpsc;
use async_trait::async_trait;
@@ -40,7 +40,6 @@ use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
use super::SendableRecordBatchStream;
use crate::execution::context::TaskContext;
use crate::physical_plan::common::spawn_execution;
-use pin_project_lite::pin_project;
/// Merge execution plan executes partitions in parallel and combines them into a single
/// partition. No guarantees are made about the order of the resulting partition.
@@ -180,26 +179,23 @@ impl ExecutionPlan for CoalescePartitionsExec {
}
}
-pin_project! {
- struct MergeStream {
- schema: SchemaRef,
- #[pin]
- input: mpsc::Receiver<ArrowResult<RecordBatch>>,
- baseline_metrics: BaselineMetrics,
- drop_helper: AbortOnDropMany<()>,
- }
+struct MergeStream {
+ schema: SchemaRef,
+ input: mpsc::Receiver<ArrowResult<RecordBatch>>,
+ baseline_metrics: BaselineMetrics,
+ #[allow(unused)]
+ drop_helper: AbortOnDropMany<()>,
}
impl Stream for MergeStream {
type Item = ArrowResult<RecordBatch>;
fn poll_next(
- self: std::pin::Pin<&mut Self>,
+ mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Self::Item>> {
- let this = self.project();
- let poll = this.input.poll_next(cx);
- this.baseline_metrics.record_poll(poll)
+ let poll = self.input.poll_recv(cx);
+ self.baseline_metrics.record_poll(poll)
}
}
diff --git a/datafusion/core/src/physical_plan/common.rs b/datafusion/core/src/physical_plan/common.rs
index b7313b9f2..68bd676dd 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -28,14 +28,14 @@ use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
use arrow::ipc::writer::FileWriter;
use arrow::record_batch::RecordBatch;
-use futures::channel::mpsc;
-use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt};
+use futures::{Future, Stream, StreamExt, TryStreamExt};
use pin_project_lite::pin_project;
use std::fs;
use std::fs::{metadata, File};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::task::{Context, Poll};
+use tokio::sync::mpsc;
use tokio::task::JoinHandle;
/// Stream of record batches
@@ -174,7 +174,7 @@ fn build_file_list_recurse(
/// Spawns a task to the tokio threadpool and writes its outputs to the provided mpsc sender
pub(crate) fn spawn_execution(
input: Arc<dyn ExecutionPlan>,
- mut output: mpsc::Sender<ArrowResult<RecordBatch>>,
+ output: mpsc::Sender<ArrowResult<RecordBatch>>,
partition: usize,
context: Arc<TaskContext>,
) -> JoinHandle<()> {
diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/core/src/physical_plan/sorts/mod.rs
index 818546f31..8d499be3a 100644
--- a/datafusion/core/src/physical_plan/sorts/mod.rs
+++ b/datafusion/core/src/physical_plan/sorts/mod.rs
@@ -22,19 +22,13 @@ use crate::error::{DataFusionError, Result};
use crate::physical_plan::{PhysicalExpr, SendableRecordBatchStream};
use arrow::array::{ArrayRef, DynComparator};
use arrow::compute::SortOptions;
-use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
-use futures::channel::mpsc;
-use futures::stream::FusedStream;
-use futures::Stream;
use hashbrown::HashMap;
use parking_lot::RwLock;
use std::borrow::BorrowMut;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
-use std::pin::Pin;
use std::sync::Arc;
-use std::task::{Context, Poll};
pub mod sort;
pub mod sort_preserving_merge;
@@ -242,44 +236,3 @@ impl SortedStream {
Self { stream, mem_used }
}
}
-
-#[derive(Debug)]
-enum StreamWrapper {
- Receiver(mpsc::Receiver<ArrowResult<RecordBatch>>),
- Stream(Option<SortedStream>),
-}
-
-impl Stream for StreamWrapper {
- type Item = ArrowResult<RecordBatch>;
-
- fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- match self.get_mut() {
- StreamWrapper::Receiver(ref mut receiver) => Pin::new(receiver).poll_next(cx),
- StreamWrapper::Stream(ref mut stream) => {
- let inner = match stream {
- None => return Poll::Ready(None),
- Some(inner) => inner,
- };
-
- match Pin::new(&mut inner.stream).poll_next(cx) {
- Poll::Ready(msg) => {
- if msg.is_none() {
- *stream = None
- }
- Poll::Ready(msg)
- }
- Poll::Pending => Poll::Pending,
- }
- }
- }
- }
-}
-
-impl FusedStream for StreamWrapper {
- fn is_terminated(&self) -> bool {
- match self {
- StreamWrapper::Receiver(receiver) => receiver.is_terminated(),
- StreamWrapper::Stream(stream) => stream.is_none(),
- }
- }
-}
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 4bc3606e0..6b8367a53 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -17,7 +17,6 @@
//! Defines the sort preserving merge plan
-use crate::physical_plan::common::AbortOnDropMany;
use crate::physical_plan::metrics::{
ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
};
@@ -25,7 +24,6 @@ use log::debug;
use parking_lot::Mutex;
use std::any::Any;
use std::collections::{BinaryHeap, VecDeque};
-use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
@@ -38,13 +36,14 @@ use arrow::{
record_batch::RecordBatch,
};
use async_trait::async_trait;
-use futures::channel::mpsc;
-use futures::stream::FusedStream;
+use futures::stream::{Fuse, FusedStream};
use futures::{Stream, StreamExt};
+use tokio::sync::mpsc;
use crate::error::{DataFusionError, Result};
use crate::execution::context::TaskContext;
-use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper};
+use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream};
+use crate::physical_plan::stream::RecordBatchReceiverStream;
use crate::physical_plan::{
common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
@@ -174,6 +173,8 @@ impl ExecutionPlan for SortPreservingMergeExec {
"Number of input partitions of SortPreservingMergeExec::execute: {}",
input_partitions
);
+ let schema = self.schema();
+
match input_partitions {
0 => Err(DataFusionError::Internal(
"SortPreservingMergeExec requires at least one input partition"
@@ -186,7 +187,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
result
}
_ => {
- let (receivers, join_handles) = (0..input_partitions)
+ let receivers = (0..input_partitions)
.into_iter()
.map(|part_i| {
let (sender, receiver) = mpsc::channel(1);
@@ -196,16 +197,23 @@ impl ExecutionPlan for SortPreservingMergeExec {
part_i,
context.clone(),
);
- (receiver, join_handle)
+
+ SortedStream::new(
+ RecordBatchReceiverStream::create(
+ &schema,
+ receiver,
+ join_handle,
+ ),
+ 0,
+ )
})
- .unzip();
+ .collect();
debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute");
- let result = Box::pin(SortPreservingMergeStream::new_from_receivers(
+ let result = Box::pin(SortPreservingMergeStream::new_from_streams(
receivers,
- AbortOnDropMany(join_handles),
- self.schema(),
+ schema,
&self.expr,
tracking_metrics,
context.session_config().batch_size,
@@ -240,16 +248,23 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}
-#[derive(Debug)]
struct MergingStreams {
/// The sorted input streams to merge together
- streams: Mutex<Vec<StreamWrapper>>,
+ streams: Mutex<Vec<Fuse<SendableRecordBatchStream>>>,
/// number of streams
num_streams: usize,
}
+impl std::fmt::Debug for MergingStreams {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("MergingStreams")
+ .field("num_streams", &self.num_streams)
+ .finish()
+ }
+}
+
impl MergingStreams {
- fn new(input_streams: Vec<StreamWrapper>) -> Self {
+ fn new(input_streams: Vec<Fuse<SendableRecordBatchStream>>) -> Self {
Self {
num_streams: input_streams.len(),
streams: Mutex::new(input_streams),
@@ -269,9 +284,6 @@ pub(crate) struct SortPreservingMergeStream {
/// The sorted input streams to merge together
streams: MergingStreams,
- /// Drop helper for tasks feeding the input [`streams`](Self::streams)
- _drop_helper: AbortOnDropMany<()>,
-
/// For each input stream maintain a dequeue of RecordBatches
///
/// Exhausted batches will be popped off the front once all
@@ -308,39 +320,6 @@ pub(crate) struct SortPreservingMergeStream {
}
impl SortPreservingMergeStream {
- pub(crate) fn new_from_receivers(
- receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
- _drop_helper: AbortOnDropMany<()>,
- schema: SchemaRef,
- expressions: &[PhysicalSortExpr],
- tracking_metrics: MemTrackingMetrics,
- batch_size: usize,
- ) -> Self {
- debug!("Start SortPreservingMergeStream::new_from_receivers");
- let stream_count = receivers.len();
- let batches = (0..stream_count)
- .into_iter()
- .map(|_| VecDeque::new())
- .collect();
- let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect();
-
- SortPreservingMergeStream {
- schema,
- batches,
- cursor_finished: vec![true; stream_count],
- streams: MergingStreams::new(wrappers),
- _drop_helper,
- column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
- sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
- tracking_metrics,
- aborted: false,
- in_progress: vec![],
- next_batch_id: 0,
- min_heap: BinaryHeap::with_capacity(stream_count),
- batch_size,
- }
- }
-
pub(crate) fn new_from_streams(
streams: Vec<SortedStream>,
schema: SchemaRef,
@@ -354,17 +333,13 @@ impl SortPreservingMergeStream {
.map(|_| VecDeque::new())
.collect();
tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum());
- let wrappers = streams
- .into_iter()
- .map(|s| StreamWrapper::Stream(Some(s)))
- .collect();
+ let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect();
Self {
schema,
batches,
cursor_finished: vec![true; stream_count],
streams: MergingStreams::new(wrappers),
- _drop_helper: AbortOnDropMany(vec![]),
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
tracking_metrics,
@@ -638,7 +613,7 @@ mod tests {
use super::*;
use crate::prelude::{SessionConfig, SessionContext};
use arrow::datatypes::{DataType, Field, Schema};
- use futures::{FutureExt, SinkExt};
+ use futures::FutureExt;
use tokio_stream::StreamExt;
#[tokio::test]
@@ -1213,11 +1188,10 @@ mod tests {
sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await;
let partition_count = batches.output_partitioning().partition_count();
- let mut join_handles = Vec::with_capacity(partition_count);
- let mut receivers = Vec::with_capacity(partition_count);
+ let mut streams = Vec::with_capacity(partition_count);
for partition in 0..partition_count {
- let (mut sender, receiver) = mpsc::channel(1);
+ let (sender, receiver) = mpsc::channel(1);
let mut stream = batches.execute(partition, task_ctx.clone()).await.unwrap();
let join_handle = tokio::spawn(async move {
while let Some(batch) = stream.next().await {
@@ -1226,17 +1200,18 @@ mod tests {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
});
- join_handles.push(join_handle);
- receivers.push(receiver);
+
+ streams.push(SortedStream::new(
+ RecordBatchReceiverStream::create(&schema, receiver, join_handle),
+ 0,
+ ));
}
let metrics = ExecutionPlanMetricsSet::new();
let tracking_metrics = MemTrackingMetrics::new(&metrics, 0);
- let merge_stream = SortPreservingMergeStream::new_from_receivers(
- receivers,
- // Use empty vector since we want to use the join handles ourselves
- AbortOnDropMany(vec![]),
+ let merge_stream = SortPreservingMergeStream::new_from_streams(
+ streams,
batches.schema(),
sort.as_slice(),
tracking_metrics,
@@ -1245,11 +1220,6 @@ mod tests {
let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
- // Propagate any errors
- for join_handle in join_handles {
- join_handle.await.unwrap();
- }
-
assert_eq!(merged.len(), 1);
let merged = merged.remove(0);
let basic = basic_sort(batches, sort.clone(), task_ctx.clone()).await;