This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new e41711b29e Generify SortPreservingMerge (#5882) (#5879) (#5886)
e41711b29e is described below

commit e41711b29e72f2d9a8a25c1ed00ccf3ed369f3bc
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Apr 7 09:21:55 2023 +0100

    Generify SortPreservingMerge (#5882) (#5879) (#5886)
    
    * Generify SortPreservingMerge (#5882) (#5879)
    
    * Review feedback
---
 datafusion/core/src/physical_plan/sorts/builder.rs | 171 +++++++
 datafusion/core/src/physical_plan/sorts/cursor.rs  |  21 +
 datafusion/core/src/physical_plan/sorts/merge.rs   | 258 +++++++++++
 datafusion/core/src/physical_plan/sorts/mod.rs     |  24 +-
 datafusion/core/src/physical_plan/sorts/sort.rs    |  30 +-
 .../physical_plan/sorts/sort_preserving_merge.rs   | 489 +--------------------
 datafusion/core/src/physical_plan/sorts/stream.rs  | 145 ++++++
 7 files changed, 632 insertions(+), 506 deletions(-)

diff --git a/datafusion/core/src/physical_plan/sorts/builder.rs b/datafusion/core/src/physical_plan/sorts/builder.rs
new file mode 100644
index 0000000000..a1941963b6
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/builder.rs
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::sorts::index::RowIndex;
+use arrow::array::{make_array, MutableArrayData};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use std::collections::VecDeque;
+
+/// Provides an API to incrementally build a [`RecordBatch`] from partitioned [`RecordBatch`]
+#[derive(Debug)]
+pub struct BatchBuilder {
+    /// The schema of the RecordBatches built by this builder
+    schema: SchemaRef,
+    /// For each input stream maintain a deque of RecordBatches
+    ///
+    /// Exhausted batches will be popped off the front once all
+    /// their rows have been yielded to the output
+    batches: Vec<VecDeque<RecordBatch>>,
+
+    /// The accumulated row indexes for the next record batch
+    indices: Vec<RowIndex>,
+}
+
+impl BatchBuilder {
+    /// Create a new [`BatchBuilder`] with the provided `stream_count` and `batch_size`
+    pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) -> Self {
+        let batches = (0..stream_count).map(|_| VecDeque::new()).collect();
+
+        Self {
+            schema,
+            batches,
+            indices: Vec::with_capacity(batch_size),
+        }
+    }
+
+    /// Append a new batch in `stream_idx`
+    pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) {
+        self.batches[stream_idx].push_back(batch)
+    }
+
+    /// Push `row_idx` from the most recently appended batch in `stream_idx` (requires at least one pushed batch)
+    pub fn push_row(&mut self, stream_idx: usize, row_idx: usize) {
+        let batch_idx = self.batches[stream_idx].len() - 1;
+        self.indices.push(RowIndex {
+            stream_idx,
+            batch_idx,
+            row_idx,
+        });
+    }
+
+    /// Returns the number of in-progress rows in this [`BatchBuilder`]
+    pub fn len(&self) -> usize {
+        self.indices.len()
+    }
+
+    /// Returns `true` if this [`BatchBuilder`] contains no in-progress rows
+    pub fn is_empty(&self) -> bool {
+        self.indices.is_empty()
+    }
+
+    /// Returns the schema of this [`BatchBuilder`]
+    pub fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+
+    /// Consumes the in-progress row indexes (`indices`), and builds a new RecordBatch from them
+    ///
+    /// Will then drop any batches for which all rows have been yielded to the output
+    ///
+    /// Returns `None` if no pending rows
+    pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
+        if self.is_empty() {
+            return Ok(None);
+        }
+
+        // Mapping from stream index to the index of that stream's first batch in the flattened batch list
+        let mut buffer_idx = 0;
+        let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
+
+        for batches in &self.batches {
+            stream_to_buffer_idx.push(buffer_idx);
+            buffer_idx += batches.len();
+        }
+
+        let columns = self
+            .schema
+            .fields()
+            .iter()
+            .enumerate()
+            .map(|(column_idx, field)| {
+                let arrays = self
+                    .batches
+                    .iter()
+                    .flat_map(|batch| {
+                        batch.iter().map(|batch| batch.column(column_idx).data())
+                    })
+                    .collect();
+
+                let mut array_data = MutableArrayData::new(
+                    arrays,
+                    field.is_nullable(),
+                    self.indices.len(),
+                );
+
+                let first = &self.indices[0];
+                let mut buffer_idx =
+                    stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
+                let mut start_row_idx = first.row_idx;
+                let mut end_row_idx = start_row_idx + 1;
+
+                for row_index in self.indices.iter().skip(1) {
+                    let next_buffer_idx =
+                        stream_to_buffer_idx[row_index.stream_idx] + row_index.batch_idx;
+
+                    if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx {
+                        // subsequent row in same batch: grow the current contiguous run
+                        end_row_idx += 1;
+                        continue;
+                    }
+
+                    // emit the current contiguous run of rows for the current buffer
+                    array_data.extend(buffer_idx, start_row_idx, end_row_idx);
+
+                    // start a new run of rows
+                    buffer_idx = next_buffer_idx;
+                    start_row_idx = row_index.row_idx;
+                    end_row_idx = start_row_idx + 1;
+                }
+
+                // emit the final run of rows
+                array_data.extend(buffer_idx, start_row_idx, end_row_idx);
+                make_array(array_data.freeze())
+            })
+            .collect();
+
+        self.indices.clear();
+
+        // New cursors are only created once the previous cursor for the stream
+        // is finished. This means all remaining rows from all but the last batch
+        // for each stream have been yielded to the newly created record batch
+        //
+        // Additionally as `indices` has been cleared, there are no longer
+        // any `RowIndex`es that depend on the current batch indexes
+        //
+        // We can therefore drop all but the last batch for each stream
+        for batches in &mut self.batches {
+            if batches.len() > 1 {
+                // Drain all but the last batch
+                batches.drain(0..(batches.len() - 1));
+            }
+        }
+
+        Ok(Some(RecordBatch::try_new(self.schema.clone(), columns)?))
+    }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs b/datafusion/core/src/physical_plan/sorts/cursor.rs
index e52544cf50..8ab2acdda4 100644
--- a/datafusion/core/src/physical_plan/sorts/cursor.rs
+++ b/datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -110,3 +110,24 @@ impl Ord for SortKeyCursor {
         }
     }
 }
+
+/// A cursor into a sorted batch of rows; cursors are ranked against each other via [`Ord`]
+pub trait Cursor: Ord {
+    /// Returns true if there are no more rows in this cursor
+    fn is_finished(&self) -> bool;
+
+    /// Advance the cursor, returning the row index it pointed at before advancing
+    ///
+    /// Returns `None` if [`Self::is_finished`]
+    fn advance(&mut self) -> Option<usize>;
+}
+
+impl Cursor for SortKeyCursor {
+    fn is_finished(&self) -> bool {
+        self.is_finished() // presumably the inherent SortKeyCursor::is_finished (inherent methods win), not recursion — confirm
+    }
+
+    fn advance(&mut self) -> Option<usize> {
+        (!self.is_finished()).then(|| self.advance()) // inner call is the inherent usize-returning advance
+    }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs b/datafusion/core/src/physical_plan/sorts/merge.rs
new file mode 100644
index 0000000000..d6b43db82b
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::metrics::MemTrackingMetrics;
+use crate::physical_plan::sorts::builder::BatchBuilder;
+use crate::physical_plan::sorts::cursor::Cursor;
+use crate::physical_plan::sorts::stream::{PartitionedStream, SortKeyCursorStream};
+use crate::physical_plan::{
+    PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
+};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use futures::Stream;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+
+/// Perform a streaming merge of [`SendableRecordBatchStream`]s that are each sorted by `expressions`
+pub(crate) fn streaming_merge(
+    streams: Vec<SendableRecordBatchStream>,
+    schema: SchemaRef,
+    expressions: &[PhysicalSortExpr],
+    tracking_metrics: MemTrackingMetrics,
+    batch_size: usize,
+) -> Result<SendableRecordBatchStream> {
+    let streams = SortKeyCursorStream::try_new(schema.as_ref(), expressions, streams)?;
+
+    Ok(Box::pin(SortPreservingMergeStream::new(
+        Box::new(streams),
+        schema,
+        tracking_metrics,
+        batch_size,
+    )))
+}
+
+/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
+type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C, RecordBatch)>>>;
+
+#[derive(Debug)]
+struct SortPreservingMergeStream<C> {
+    in_progress: BatchBuilder, // accumulates the rows (and source batches) for the next output batch
+
+    /// The sorted input streams to merge together
+    streams: CursorStream<C>,
+
+    /// used to record execution metrics
+    tracking_metrics: MemTrackingMetrics,
+
+    /// If the stream has encountered an error
+    aborted: bool,
+
+    /// A loser tree that always produces the minimum cursor
+    ///
+    /// Node 0 stores the top winner, Nodes 1..cursors.len() store
+    /// the loser nodes
+    ///
+    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
+    /// track of the current smallest element at the top. When the top
+    /// record is taken, the tree structure is not modified, and only
+    /// the path from bottom to top is visited, keeping the number of
+    /// comparisons close to the theoretical limit of `log(S)`.
+    ///
+    /// reference: <https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
+    loser_tree: Vec<usize>,
+
+    /// If the most recently yielded overall winner has been replaced
+    /// within the loser tree. A value of `false` indicates that the
+    /// overall winner has been yielded but the loser tree has not
+    /// been updated
+    loser_tree_adjusted: bool,
+
+    /// target batch size
+    batch_size: usize,
+
+    /// The current cursor for each input partition; `None` until the partition yields its first batch
+    cursors: Vec<Option<C>>,
+}
+
+impl<C: Cursor> SortPreservingMergeStream<C> {
+    fn new(
+        streams: CursorStream<C>,
+        schema: SchemaRef,
+        tracking_metrics: MemTrackingMetrics,
+        batch_size: usize,
+    ) -> Self {
+        let stream_count = streams.partitions();
+
+        Self {
+            in_progress: BatchBuilder::new(schema, stream_count, batch_size),
+            streams,
+            tracking_metrics,
+            aborted: false,
+            cursors: (0..stream_count).map(|_| None).collect(),
+            loser_tree: vec![],
+            loser_tree_adjusted: false,
+            batch_size,
+        }
+    }
+
+    /// If the cursor for stream `idx` is absent or finished, poll the stream for the
+    /// next RecordBatch and, if one is returned, install its cursor and push the
+    /// batch into `in_progress`; otherwise leave the existing cursor untouched
+    fn maybe_poll_stream(
+        &mut self,
+        cx: &mut Context<'_>,
+        idx: usize,
+    ) -> Poll<Result<()>> {
+        if self.cursors[idx]
+            .as_ref()
+            .map(|cursor| !cursor.is_finished())
+            .unwrap_or(false)
+        {
+            // Cursor is not finished - don't need a new RecordBatch yet
+            return Poll::Ready(Ok(()));
+        }
+
+        match futures::ready!(self.streams.poll_next(cx, idx)) {
+            None => Poll::Ready(Ok(())),
+            Some(Err(e)) => Poll::Ready(Err(e)),
+            Some(Ok((cursor, batch))) => {
+                self.cursors[idx] = Some(cursor);
+                self.in_progress.push_batch(idx, batch);
+                Poll::Ready(Ok(()))
+            }
+        }
+    }
+
+    fn poll_next_inner(
+        &mut self,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        if self.aborted {
+            return Poll::Ready(None);
+        }
+        // try to initialize the loser tree
+        if self.loser_tree.is_empty() {
+            // Ensure all non-exhausted streams have a cursor from which
+            // rows can be pulled
+            for i in 0..self.streams.partitions() {
+                if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(e)));
+                }
+            }
+            self.init_loser_tree();
+        }
+
+        // NB timer records time taken on drop, so there are no
+        // calls to `timer.done()` below.
+        let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
+        let _timer = elapsed_compute.timer();
+
+        loop {
+            // Adjust the loser tree if necessary, returning control if needed
+            if !self.loser_tree_adjusted {
+                let winner = self.loser_tree[0];
+                if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
+                    self.aborted = true;
+                    return Poll::Ready(Some(Err(e)));
+                }
+                self.update_loser_tree();
+            }
+
+            let stream_idx = self.loser_tree[0];
+            let cursor = self.cursors[stream_idx].as_mut();
+            if let Some(row_idx) = cursor.and_then(Cursor::advance) {
+                self.loser_tree_adjusted = false;
+                self.in_progress.push_row(stream_idx, row_idx);
+                if self.in_progress.len() < self.batch_size {
+                    continue;
+                }
+            }
+
+            return Poll::Ready(self.in_progress.build_record_batch().transpose());
+        }
+    }
+
+    /// Returns `true` if the cursor at index `a` is greater than at index `b`, treating a missing (`None`) cursor as greater than any present one
+    #[inline]
+    fn is_gt(&self, a: usize, b: usize) -> bool {
+        match (&self.cursors[a], &self.cursors[b]) {
+            (None, _) => true,
+            (_, None) => false,
+            (Some(a), Some(b)) => b < a,
+        }
+    }
+
+    /// Attempts to initialize the loser tree with one value from each
+    /// non exhausted input, if possible
+    fn init_loser_tree(&mut self) {
+        // Init loser tree
+        self.loser_tree = vec![usize::MAX; self.cursors.len()];
+        for i in 0..self.cursors.len() {
+            let mut winner = i;
+            let mut cmp_node = (self.cursors.len() + i) / 2;
+            while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
+                let challenger = self.loser_tree[cmp_node];
+                if self.is_gt(winner, challenger) {
+                    self.loser_tree[cmp_node] = winner;
+                    winner = challenger;
+                }
+
+                cmp_node /= 2;
+            }
+            self.loser_tree[cmp_node] = winner;
+        }
+        self.loser_tree_adjusted = true;
+    }
+
+    /// Attempts to update the loser tree, following winner replacement, if possible
+    fn update_loser_tree(&mut self) {
+        let mut winner = self.loser_tree[0];
+        // Replace overall winner by walking tree of losers
+        let mut cmp_node = (self.cursors.len() + winner) / 2;
+        while cmp_node != 0 {
+            let challenger = self.loser_tree[cmp_node];
+            if self.is_gt(winner, challenger) {
+                self.loser_tree[cmp_node] = winner;
+                winner = challenger;
+            }
+            cmp_node /= 2;
+        }
+        self.loser_tree[0] = winner;
+        self.loser_tree_adjusted = true;
+    }
+}
+
+impl<C: Cursor + Unpin> Stream for SortPreservingMergeStream<C> {
+    type Item = Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let poll = self.poll_next_inner(cx); // all merge logic lives in poll_next_inner
+        self.tracking_metrics.record_poll(poll) // hand the poll result to the metrics tracker, which returns it
+    }
+}
+
+impl<C: Cursor + Unpin> RecordBatchStream for SortPreservingMergeStream<C> {
+    fn schema(&self) -> SchemaRef {
+        self.in_progress.schema().clone() // the output schema is held by the BatchBuilder
+    }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs b/datafusion/core/src/physical_plan/sorts/mod.rs
index db6ab5c604..cd5dae27dc 100644
--- a/datafusion/core/src/physical_plan/sorts/mod.rs
+++ b/datafusion/core/src/physical_plan/sorts/mod.rs
@@ -17,30 +17,14 @@
 
 //! Sort functionalities
 
-use crate::physical_plan::SendableRecordBatchStream;
-use std::fmt::{Debug, Formatter};
-
+mod builder;
 mod cursor;
 mod index;
+mod merge;
 pub mod sort;
 pub mod sort_preserving_merge;
+mod stream;
 
 pub use cursor::SortKeyCursor;
 pub use index::RowIndex;
-
-pub(crate) struct SortedStream {
-    stream: SendableRecordBatchStream,
-    mem_used: usize,
-}
-
-impl Debug for SortedStream {
-    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
-        write!(f, "InMemSorterStream")
-    }
-}
-
-impl SortedStream {
-    pub(crate) fn new(stream: SendableRecordBatchStream, mem_used: usize) -> Self {
-        Self { stream, mem_used }
-    }
-}
+pub(crate) use merge::streaming_merge;
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs
index c3fc06206c..1eceb02c2d 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -30,8 +30,7 @@ use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{
     BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet,
 };
-use 
crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
-use crate::physical_plan::sorts::SortedStream;
+use crate::physical_plan::sorts::merge::streaming_merge;
 use crate::physical_plan::stream::{RecordBatchReceiverStream, 
RecordBatchStreamAdapter};
 use crate::physical_plan::{
     DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, 
Partitioning,
@@ -169,37 +168,40 @@ impl ExternalSorter {
         let batch_size = self.session_config.batch_size();
 
         if self.spilled_before() {
-            let tracking_metrics = self
+            let intermediate_metrics = self
                 .metrics_set
                 .new_intermediate_tracking(self.partition_id, 
&self.runtime.memory_pool);
-            let mut streams: Vec<SortedStream> = vec![];
+            let mut merge_metrics = self
+                .metrics_set
+                .new_final_tracking(self.partition_id, 
&self.runtime.memory_pool);
+
+            let mut streams = vec![];
             if !self.in_mem_batches.is_empty() {
                 let in_mem_stream = in_mem_partial_sort(
                     &mut self.in_mem_batches,
                     self.schema.clone(),
                     &self.expr,
                     batch_size,
-                    tracking_metrics,
+                    intermediate_metrics,
                     self.fetch,
                 )?;
-                let prev_used = self.reservation.free();
-                streams.push(SortedStream::new(in_mem_stream, prev_used));
+                // TODO: More accurate, dynamic memory accounting (#5885)
+                merge_metrics.init_mem_used(self.reservation.free());
+                streams.push(in_mem_stream);
             }
 
             for spill in self.spills.drain(..) {
                 let stream = read_spill_as_stream(spill, self.schema.clone())?;
-                streams.push(SortedStream::new(stream, 0));
+                streams.push(stream);
             }
-            let tracking_metrics = self
-                .metrics_set
-                .new_final_tracking(self.partition_id, 
&self.runtime.memory_pool);
-            Ok(Box::pin(SortPreservingMergeStream::new_from_streams(
+
+            streaming_merge(
                 streams,
                 self.schema.clone(),
                 &self.expr,
-                tracking_metrics,
+                merge_metrics,
                 self.session_config.batch_size(),
-            )?))
+            )
         } else if !self.in_mem_batches.is_empty() {
             let tracking_metrics = self
                 .metrics_set
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 14204ef3b4..98cfb24c92 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -18,19 +18,9 @@
 //! Defines the sort preserving merge plan
 
 use std::any::Any;
-use std::collections::VecDeque;
-use std::pin::Pin;
 use std::sync::Arc;
-use std::task::{Context, Poll};
 
-use arrow::row::{RowConverter, SortField};
-use arrow::{
-    array::{make_array as make_arrow_array, MutableArrayData},
-    datatypes::SchemaRef,
-    record_batch::RecordBatch,
-};
-use futures::stream::{Fuse, FusedStream};
-use futures::{ready, Stream, StreamExt};
+use arrow::datatypes::SchemaRef;
 use log::debug;
 use tokio::sync::mpsc;
 
@@ -39,12 +29,11 @@ use crate::execution::context::TaskContext;
 use crate::physical_plan::metrics::{
     ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
 };
-use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream};
+use crate::physical_plan::sorts::streaming_merge;
 use crate::physical_plan::stream::RecordBatchReceiverStream;
 use crate::physical_plan::{
     common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
-    Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
-    SendableRecordBatchStream, Statistics,
+    Distribution, ExecutionPlan, Partitioning, SendableRecordBatchStream, 
Statistics,
 };
 use datafusion_physical_expr::{
     make_sort_requirements_from_exprs, EquivalenceProperties, 
PhysicalSortRequirement,
@@ -206,34 +195,27 @@ impl ExecutionPlan for SortPreservingMergeExec {
                                 context.clone(),
                             );
 
-                            SortedStream::new(
-                                RecordBatchReceiverStream::create(
-                                    &schema,
-                                    receiver,
-                                    join_handle,
-                                ),
-                                0,
+                            RecordBatchReceiverStream::create(
+                                &schema,
+                                receiver,
+                                join_handle,
                             )
                         })
                         .collect(),
                     Err(_) => (0..input_partitions)
-                        .map(|partition| {
-                            let stream =
-                                self.input.execute(partition, 
context.clone())?;
-                            Ok(SortedStream::new(stream, 0))
-                        })
+                        .map(|partition| self.input.execute(partition, 
context.clone()))
                         .collect::<Result<_>>()?,
                 };
 
                 debug!("Done setting up sender-receiver for 
SortPreservingMergeExec::execute");
 
-                let result = 
Box::pin(SortPreservingMergeStream::new_from_streams(
+                let result = streaming_merge(
                     receivers,
                     schema,
                     &self.expr,
                     tracking_metrics,
                     context.session_config().batch_size(),
-                )?);
+                )?;
 
                 debug!("Got stream result from 
SortPreservingMergeStream::new_from_receivers");
 
@@ -264,445 +246,6 @@ impl ExecutionPlan for SortPreservingMergeExec {
     }
 }
 
-struct MergingStreams {
-    /// The sorted input streams to merge together
-    streams: Vec<Fuse<SendableRecordBatchStream>>,
-    /// number of streams
-    num_streams: usize,
-}
-
-impl std::fmt::Debug for MergingStreams {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        f.debug_struct("MergingStreams")
-            .field("num_streams", &self.num_streams)
-            .finish()
-    }
-}
-
-impl MergingStreams {
-    fn new(input_streams: Vec<Fuse<SendableRecordBatchStream>>) -> Self {
-        Self {
-            num_streams: input_streams.len(),
-            streams: input_streams,
-        }
-    }
-
-    fn num_streams(&self) -> usize {
-        self.num_streams
-    }
-}
-
-#[derive(Debug)]
-pub(crate) struct SortPreservingMergeStream {
-    /// The schema of the RecordBatches yielded by this stream
-    schema: SchemaRef,
-
-    /// The sorted input streams to merge together
-    streams: MergingStreams,
-
-    /// For each input stream maintain a dequeue of RecordBatches
-    ///
-    /// Exhausted batches will be popped off the front once all
-    /// their rows have been yielded to the output
-    batches: Vec<VecDeque<RecordBatch>>,
-
-    /// The accumulated row indexes for the next record batch
-    in_progress: Vec<RowIndex>,
-
-    /// The physical expressions to sort by
-    column_expressions: Vec<Arc<dyn PhysicalExpr>>,
-
-    /// used to record execution metrics
-    tracking_metrics: MemTrackingMetrics,
-
-    /// If the stream has encountered an error
-    aborted: bool,
-
-    /// Vector that holds all [`SortKeyCursor`]s
-    cursors: Vec<Option<SortKeyCursor>>,
-
-    /// A loser tree that always produces the minimum cursor
-    ///
-    /// Node 0 stores the top winner, Nodes 1..num_streams store
-    /// the loser nodes
-    ///
-    /// This implements a "Tournament Tree" (aka Loser Tree) to keep
-    /// track of the current smallest element at the top. When the top
-    /// record is taken, the tree structure is not modified, and only
-    /// the path from bottom to top is visited, keeping the number of
-    /// comparisons close to the theoretical limit of `log(S)`.
-    ///
-    /// reference: 
<https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
-    loser_tree: Vec<usize>,
-
-    /// If the most recently yielded overall winner has been replaced
-    /// within the loser tree. A value of `false` indicates that the
-    /// overall winner has been yielded but the loser tree has not
-    /// been updated
-    loser_tree_adjusted: bool,
-
-    /// target batch size
-    batch_size: usize,
-
-    /// row converter
-    row_converter: RowConverter,
-}
-
-impl SortPreservingMergeStream {
-    pub(crate) fn new_from_streams(
-        streams: Vec<SortedStream>,
-        schema: SchemaRef,
-        expressions: &[PhysicalSortExpr],
-        mut tracking_metrics: MemTrackingMetrics,
-        batch_size: usize,
-    ) -> Result<Self> {
-        let stream_count = streams.len();
-        let batches = (0..stream_count).map(|_| VecDeque::new()).collect();
-        tracking_metrics.init_mem_used(streams.iter().map(|s| 
s.mem_used).sum());
-        let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect();
-
-        let sort_fields = expressions
-            .iter()
-            .map(|expr| {
-                let data_type = expr.expr.data_type(&schema)?;
-                Ok(SortField::new_with_options(data_type, expr.options))
-            })
-            .collect::<Result<Vec<_>>>()?;
-        let row_converter = RowConverter::new(sort_fields)?;
-
-        Ok(Self {
-            schema,
-            batches,
-            streams: MergingStreams::new(wrappers),
-            column_expressions: expressions.iter().map(|x| 
x.expr.clone()).collect(),
-            tracking_metrics,
-            aborted: false,
-            in_progress: vec![],
-            cursors: (0..stream_count).map(|_| None).collect(),
-            loser_tree: Vec::with_capacity(stream_count),
-            loser_tree_adjusted: false,
-            batch_size,
-            row_converter,
-        })
-    }
-
-    /// If the stream at the given index is not exhausted, and the last cursor 
for the
-    /// stream is finished, poll the stream for the next RecordBatch and 
create a new
-    /// cursor for the stream from the returned result
-    fn maybe_poll_stream(
-        &mut self,
-        cx: &mut Context<'_>,
-        idx: usize,
-    ) -> Poll<Result<()>> {
-        if self.cursors[idx]
-            .as_ref()
-            .map(|cursor| !cursor.is_finished())
-            .unwrap_or(false)
-        {
-            // Cursor is not finished - don't need a new RecordBatch yet
-            return Poll::Ready(Ok(()));
-        }
-        let mut empty_batch = false;
-        {
-            let stream = &mut self.streams.streams[idx];
-            if stream.is_terminated() {
-                return Poll::Ready(Ok(()));
-            }
-
-            // Fetch a new input record and create a cursor from it
-            match futures::ready!(stream.poll_next_unpin(cx)) {
-                None => return Poll::Ready(Ok(())),
-                Some(Err(e)) => {
-                    return Poll::Ready(Err(e));
-                }
-                Some(Ok(batch)) => {
-                    if batch.num_rows() > 0 {
-                        let cols = self
-                            .column_expressions
-                            .iter()
-                            .map(|expr| {
-                                
Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))
-                            })
-                            .collect::<Result<Vec<_>>>()?;
-
-                        let rows = match 
self.row_converter.convert_columns(&cols) {
-                            Ok(rows) => rows,
-                            Err(e) => {
-                                return 
Poll::Ready(Err(DataFusionError::ArrowError(e)));
-                            }
-                        };
-
-                        self.cursors[idx] = Some(SortKeyCursor::new(idx, 
rows));
-                        self.batches[idx].push_back(batch)
-                    } else {
-                        empty_batch = true;
-                    }
-                }
-            }
-        }
-
-        if empty_batch {
-            self.maybe_poll_stream(cx, idx)
-        } else {
-            Poll::Ready(Ok(()))
-        }
-    }
-
-    /// Drains the in_progress row indexes, and builds a new RecordBatch from 
them
-    ///
-    /// Will then drop any batches for which all rows have been yielded to the 
output
-    fn build_record_batch(&mut self) -> Result<RecordBatch> {
-        // Mapping from stream index to the index of the first buffer from 
that stream
-        let mut buffer_idx = 0;
-        let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
-
-        for batches in &self.batches {
-            stream_to_buffer_idx.push(buffer_idx);
-            buffer_idx += batches.len();
-        }
-
-        let columns = self
-            .schema
-            .fields()
-            .iter()
-            .enumerate()
-            .map(|(column_idx, field)| {
-                let arrays = self
-                    .batches
-                    .iter()
-                    .flat_map(|batch| {
-                        batch.iter().map(|batch| 
batch.column(column_idx).data())
-                    })
-                    .collect();
-
-                let mut array_data = MutableArrayData::new(
-                    arrays,
-                    field.is_nullable(),
-                    self.in_progress.len(),
-                );
-
-                if self.in_progress.is_empty() {
-                    return make_arrow_array(array_data.freeze());
-                }
-
-                let first = &self.in_progress[0];
-                let mut buffer_idx =
-                    stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
-                let mut start_row_idx = first.row_idx;
-                let mut end_row_idx = start_row_idx + 1;
-
-                for row_index in self.in_progress.iter().skip(1) {
-                    let next_buffer_idx =
-                        stream_to_buffer_idx[row_index.stream_idx] + 
row_index.batch_idx;
-
-                    if next_buffer_idx == buffer_idx && row_index.row_idx == 
end_row_idx {
-                        // subsequent row in same batch
-                        end_row_idx += 1;
-                        continue;
-                    }
-
-                    // emit current batch of rows for current buffer
-                    array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-
-                    // start new batch of rows
-                    buffer_idx = next_buffer_idx;
-                    start_row_idx = row_index.row_idx;
-                    end_row_idx = start_row_idx + 1;
-                }
-
-                // emit final batch of rows
-                array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-                make_arrow_array(array_data.freeze())
-            })
-            .collect();
-
-        self.in_progress.clear();
-
-        // New cursors are only created once the previous cursor for the stream
-        // is finished. This means all remaining rows from all but the last 
batch
-        // for each stream have been yielded to the newly created record batch
-        //
-        // Additionally as `in_progress` has been drained, there are no longer
-        // any RowIndex's reliant on the batch indexes
-        //
-        // We can therefore drop all but the last batch for each stream
-        for batches in &mut self.batches {
-            if batches.len() > 1 {
-                // Drain all but the last batch
-                batches.drain(0..(batches.len() - 1));
-            }
-        }
-
-        RecordBatch::try_new(self.schema.clone(), columns).map_err(Into::into)
-    }
-}
-
-impl Stream for SortPreservingMergeStream {
-    type Item = Result<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        let poll = self.poll_next_inner(cx);
-        self.tracking_metrics.record_poll(poll)
-    }
-}
-
-impl SortPreservingMergeStream {
-    #[inline]
-    fn poll_next_inner(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Result<RecordBatch>>> {
-        if self.aborted {
-            return Poll::Ready(None);
-        }
-        // try to initialize the loser tree
-        if let Err(e) = ready!(self.init_loser_tree(cx)) {
-            return Poll::Ready(Some(Err(e)));
-        }
-
-        // NB timer records time taken on drop, so there are no
-        // calls to `timer.done()` below.
-        let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
-        let _timer = elapsed_compute.timer();
-
-        loop {
-            // Adjust the loser tree if necessary, returning control if needed
-            if let Err(e) = ready!(self.update_loser_tree(cx)) {
-                return Poll::Ready(Some(Err(e)));
-            }
-
-            let min_cursor_idx = self.loser_tree[0];
-            let next = self.cursors[min_cursor_idx]
-                .as_mut()
-                .filter(|cursor| !cursor.is_finished())
-                .map(|cursor| (cursor.stream_idx(), cursor.advance()));
-
-            if let Some((stream_idx, row_idx)) = next {
-                self.loser_tree_adjusted = false;
-                let batch_idx = self.batches[stream_idx].len() - 1;
-                self.in_progress.push(RowIndex {
-                    stream_idx,
-                    batch_idx,
-                    row_idx,
-                });
-                if self.in_progress.len() == self.batch_size {
-                    return Poll::Ready(Some(self.build_record_batch()));
-                }
-            } else if !self.in_progress.is_empty() {
-                return Poll::Ready(Some(self.build_record_batch()));
-            } else {
-                return Poll::Ready(None);
-            }
-        }
-    }
-
-    /// Attempts to initialize the loser tree with one value from each
-    /// non exhausted input, if possible.
-    ///
-    /// Returns
-    /// * Poll::Pending when more data is needed
-    /// * Poll::Ready(Ok()) on success
-    /// * Poll::Ready(Err..) if any of the inputs  errored
-    #[inline]
-    fn init_loser_tree(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<()>> {
-        let num_streams = self.streams.num_streams();
-
-        if !self.loser_tree.is_empty() {
-            return Poll::Ready(Ok(()));
-        }
-
-        // Ensure all non-exhausted streams have a cursor from which
-        // rows can be pulled
-        for i in 0..num_streams {
-            if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
-                self.aborted = true;
-                return Poll::Ready(Err(e));
-            }
-        }
-
-        // Init loser tree
-        self.loser_tree.resize(num_streams, usize::MAX);
-        for i in 0..num_streams {
-            let mut winner = i;
-            let mut cmp_node = (num_streams + i) / 2;
-            while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
-                let challenger = self.loser_tree[cmp_node];
-                let challenger_win =
-                    match (&self.cursors[winner], &self.cursors[challenger]) {
-                        (None, _) => true,
-                        (_, None) => false,
-                        (Some(winner), Some(challenger)) => challenger < 
winner,
-                    };
-
-                if challenger_win {
-                    self.loser_tree[cmp_node] = winner;
-                    winner = challenger;
-                }
-
-                cmp_node /= 2;
-            }
-            self.loser_tree[cmp_node] = winner;
-        }
-        self.loser_tree_adjusted = true;
-        Poll::Ready(Ok(()))
-    }
-
-    /// Attempts to updated the loser tree, if possible
-    ///
-    /// Returns
-    /// * Poll::Pending when the winning unput was not ready
-    /// * Poll::Ready(Ok()) on success
-    /// * Poll::Ready(Err..) if any of the winning input erroed
-    #[inline]
-    fn update_loser_tree(
-        self: &mut Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<()>> {
-        if self.loser_tree_adjusted {
-            return Poll::Ready(Ok(()));
-        }
-
-        let num_streams = self.streams.num_streams();
-        let mut winner = self.loser_tree[0];
-        if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
-            self.aborted = true;
-            return Poll::Ready(Err(e));
-        }
-
-        // Replace overall winner by walking tree of losers
-        let mut cmp_node = (num_streams + winner) / 2;
-        while cmp_node != 0 {
-            let challenger = self.loser_tree[cmp_node];
-            let challenger_win = match (&self.cursors[winner], 
&self.cursors[challenger])
-            {
-                (None, _) => true,
-                (_, None) => false,
-                (Some(winner), Some(challenger)) => challenger < winner,
-            };
-            if challenger_win {
-                self.loser_tree[cmp_node] = winner;
-                winner = challenger;
-            }
-            cmp_node /= 2;
-        }
-        self.loser_tree[0] = winner;
-        self.loser_tree_adjusted = true;
-        Poll::Ready(Ok(()))
-    }
-}
-
-impl RecordBatchStream for SortPreservingMergeStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use std::iter::FromIterator;
@@ -710,6 +253,7 @@ mod tests {
     use arrow::array::ArrayRef;
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
+    use arrow::record_batch::RecordBatch;
     use futures::FutureExt;
     use tokio_stream::StreamExt;
 
@@ -1284,9 +828,10 @@ mod tests {
                 }
             });
 
-            streams.push(SortedStream::new(
-                RecordBatchReceiverStream::create(&schema, receiver, 
join_handle),
-                0,
+            streams.push(RecordBatchReceiverStream::create(
+                &schema,
+                receiver,
+                join_handle,
             ));
         }
 
@@ -1294,7 +839,7 @@ mod tests {
         let tracking_metrics =
             MemTrackingMetrics::new(&metrics, task_ctx.memory_pool(), 0);
 
-        let merge_stream = SortPreservingMergeStream::new_from_streams(
+        let merge_stream = streaming_merge(
             streams,
             batches.schema(),
             sort.as_slice(),
@@ -1303,7 +848,7 @@ mod tests {
         )
         .unwrap();
 
-        let mut merged = 
common::collect(Box::pin(merge_stream)).await.unwrap();
+        let mut merged = common::collect(merge_stream).await.unwrap();
 
         assert_eq!(merged.len(), 1);
         let merged = merged.remove(0);
diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs 
b/datafusion/core/src/physical_plan/sorts/stream.rs
new file mode 100644
index 0000000000..1bc046042e
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/stream.rs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::sorts::cursor::SortKeyCursor;
+use crate::physical_plan::SendableRecordBatchStream;
+use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr};
+use arrow::datatypes::Schema;
+use arrow::record_batch::RecordBatch;
+use arrow::row::{RowConverter, SortField};
+use futures::stream::{Fuse, StreamExt};
+use std::sync::Arc;
+use std::task::{ready, Context, Poll};
+
+/// A [`Stream`](futures::Stream) that has multiple partitions that can
+/// be polled separately but not concurrently
+///
+/// Used by sort preserving merge to decouple the cursor merging logic from
+/// the source of the cursors, the intention being to allow preserving
+/// any row encoding performed for intermediate sorts
+pub trait PartitionedStream: std::fmt::Debug + Send {
+    type Output;
+
+    /// Returns the number of partitions
+    fn partitions(&self) -> usize;
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Self::Output>>;
+}
+
+/// A newtype wrapper around a set of fused [`SendableRecordBatchStream`]
+/// that implements debug, and skips over empty [`RecordBatch`]
+struct FusedStreams(Vec<Fuse<SendableRecordBatchStream>>);
+
+impl std::fmt::Debug for FusedStreams {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("FusedStreams")
+            .field("num_streams", &self.0.len())
+            .finish()
+    }
+}
+
+impl FusedStreams {
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Result<RecordBatch>>> {
+        loop {
+            match ready!(self.0[stream_idx].poll_next_unpin(cx)) {
+                Some(Ok(b)) if b.num_rows() == 0 => continue,
+                r => return Poll::Ready(r),
+            }
+        }
+    }
+}
+
+/// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`]
+/// and computes [`SortKeyCursor`] based on the provided [`PhysicalSortExpr`]
+#[derive(Debug)]
+pub(crate) struct SortKeyCursorStream {
+    /// Converter to convert output of physical expressions
+    converter: RowConverter,
+    /// The physical expressions to sort by
+    column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+    /// Input streams
+    streams: FusedStreams,
+}
+
+impl SortKeyCursorStream {
+    pub(crate) fn try_new(
+        schema: &Schema,
+        expressions: &[PhysicalSortExpr],
+        streams: Vec<SendableRecordBatchStream>,
+    ) -> Result<Self> {
+        let sort_fields = expressions
+            .iter()
+            .map(|expr| {
+                let data_type = expr.expr.data_type(schema)?;
+                Ok(SortField::new_with_options(data_type, expr.options))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        let streams = streams.into_iter().map(|s| s.fuse()).collect();
+        let converter = RowConverter::new(sort_fields)?;
+        Ok(Self {
+            converter,
+            column_expressions: expressions.iter().map(|x| 
x.expr.clone()).collect(),
+            streams: FusedStreams(streams),
+        })
+    }
+
+    fn convert_batch(
+        &mut self,
+        batch: &RecordBatch,
+        stream_idx: usize,
+    ) -> Result<SortKeyCursor> {
+        let cols = self
+            .column_expressions
+            .iter()
+            .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
+            .collect::<Result<Vec<_>>>()?;
+
+        let rows = self.converter.convert_columns(&cols)?;
+        Ok(SortKeyCursor::new(stream_idx, rows))
+    }
+}
+
+impl PartitionedStream for SortKeyCursorStream {
+    type Output = Result<(SortKeyCursor, RecordBatch)>;
+
+    fn partitions(&self) -> usize {
+        self.streams.0.len()
+    }
+
+    fn poll_next(
+        &mut self,
+        cx: &mut Context<'_>,
+        stream_idx: usize,
+    ) -> Poll<Option<Self::Output>> {
+        Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| {
+            r.and_then(|batch| {
+                let cursor = self.convert_batch(&batch, stream_idx)?;
+                Ok((cursor, batch))
+            })
+        }))
+    }
+}


Reply via email to