This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new e41711b29e Generify SortPreservingMerge (#5882) (#5879) (#5886)
e41711b29e is described below
commit e41711b29e72f2d9a8a25c1ed00ccf3ed369f3bc
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Fri Apr 7 09:21:55 2023 +0100
Generify SortPreservingMerge (#5882) (#5879) (#5886)
* Generify SortPreservingMerge (#5882) (#5879)
* Review feedback
---
datafusion/core/src/physical_plan/sorts/builder.rs | 171 +++++++
datafusion/core/src/physical_plan/sorts/cursor.rs | 21 +
datafusion/core/src/physical_plan/sorts/merge.rs | 258 +++++++++++
datafusion/core/src/physical_plan/sorts/mod.rs | 24 +-
datafusion/core/src/physical_plan/sorts/sort.rs | 30 +-
.../physical_plan/sorts/sort_preserving_merge.rs | 489 +--------------------
datafusion/core/src/physical_plan/sorts/stream.rs | 145 ++++++
7 files changed, 632 insertions(+), 506 deletions(-)
diff --git a/datafusion/core/src/physical_plan/sorts/builder.rs
b/datafusion/core/src/physical_plan/sorts/builder.rs
new file mode 100644
index 0000000000..a1941963b6
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/builder.rs
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::sorts::index::RowIndex;
+use arrow::array::{make_array, MutableArrayData};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use std::collections::VecDeque;
+
+/// Provides an API to incrementally build a [`RecordBatch`] from partitioned
+/// streams of [`RecordBatch`]es
+#[derive(Debug)]
+pub struct BatchBuilder {
+ /// The schema of the RecordBatches yielded by this stream
+ schema: SchemaRef,
+ /// For each input stream maintain a dequeue of RecordBatches
+ ///
+ /// Exhausted batches will be popped off the front once all
+ /// their rows have been yielded to the output
+ batches: Vec<VecDeque<RecordBatch>>,
+
+ /// The accumulated row indexes for the next record batch
+ indices: Vec<RowIndex>,
+}
+
+impl BatchBuilder {
+ /// Create a new [`BatchBuilder`] with the provided `stream_count` and
`batch_size`
+ pub fn new(schema: SchemaRef, stream_count: usize, batch_size: usize) ->
Self {
+ let batches = (0..stream_count).map(|_| VecDeque::new()).collect();
+
+ Self {
+ schema,
+ batches,
+ indices: Vec::with_capacity(batch_size),
+ }
+ }
+
+ /// Append a new batch in `stream_idx`
+ pub fn push_batch(&mut self, stream_idx: usize, batch: RecordBatch) {
+ self.batches[stream_idx].push_back(batch)
+ }
+
+ /// Push `row_idx` from the most recently appended batch in `stream_idx`
+ pub fn push_row(&mut self, stream_idx: usize, row_idx: usize) {
+ let batch_idx = self.batches[stream_idx].len() - 1;
+ self.indices.push(RowIndex {
+ stream_idx,
+ batch_idx,
+ row_idx,
+ });
+ }
+
+ /// Returns the number of in-progress rows in this [`BatchBuilder`]
+ pub fn len(&self) -> usize {
+ self.indices.len()
+ }
+
+ /// Returns `true` if this [`BatchBuilder`] contains no in-progress rows
+ pub fn is_empty(&self) -> bool {
+ self.indices.is_empty()
+ }
+
+ /// Returns the schema of this [`BatchBuilder`]
+ pub fn schema(&self) -> &SchemaRef {
+ &self.schema
+ }
+
+    /// Drains the in-progress row indexes, and builds a new [`RecordBatch`]
+    /// from them
+ ///
+ /// Will then drop any batches for which all rows have been yielded to the
output
+ ///
+ /// Returns `None` if no pending rows
+ pub fn build_record_batch(&mut self) -> Result<Option<RecordBatch>> {
+ if self.is_empty() {
+ return Ok(None);
+ }
+
+ // Mapping from stream index to the index of the first buffer from
that stream
+ let mut buffer_idx = 0;
+ let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
+
+ for batches in &self.batches {
+ stream_to_buffer_idx.push(buffer_idx);
+ buffer_idx += batches.len();
+ }
+
+ let columns = self
+ .schema
+ .fields()
+ .iter()
+ .enumerate()
+ .map(|(column_idx, field)| {
+ let arrays = self
+ .batches
+ .iter()
+ .flat_map(|batch| {
+ batch.iter().map(|batch|
batch.column(column_idx).data())
+ })
+ .collect();
+
+ let mut array_data = MutableArrayData::new(
+ arrays,
+ field.is_nullable(),
+ self.indices.len(),
+ );
+
+ let first = &self.indices[0];
+ let mut buffer_idx =
+ stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
+ let mut start_row_idx = first.row_idx;
+ let mut end_row_idx = start_row_idx + 1;
+
+ for row_index in self.indices.iter().skip(1) {
+ let next_buffer_idx =
+ stream_to_buffer_idx[row_index.stream_idx] +
row_index.batch_idx;
+
+ if next_buffer_idx == buffer_idx && row_index.row_idx ==
end_row_idx {
+ // subsequent row in same batch
+ end_row_idx += 1;
+ continue;
+ }
+
+ // emit current batch of rows for current buffer
+ array_data.extend(buffer_idx, start_row_idx, end_row_idx);
+
+ // start new batch of rows
+ buffer_idx = next_buffer_idx;
+ start_row_idx = row_index.row_idx;
+ end_row_idx = start_row_idx + 1;
+ }
+
+ // emit final batch of rows
+ array_data.extend(buffer_idx, start_row_idx, end_row_idx);
+ make_array(array_data.freeze())
+ })
+ .collect();
+
+ self.indices.clear();
+
+ // New cursors are only created once the previous cursor for the stream
+ // is finished. This means all remaining rows from all but the last
batch
+ // for each stream have been yielded to the newly created record batch
+ //
+        // Additionally as `indices` has been drained, there are no longer
+        // any RowIndex's reliant on the batch indexes
+ //
+ // We can therefore drop all but the last batch for each stream
+ for batches in &mut self.batches {
+ if batches.len() > 1 {
+ // Drain all but the last batch
+ batches.drain(0..(batches.len() - 1));
+ }
+ }
+
+ Ok(Some(RecordBatch::try_new(self.schema.clone(), columns)?))
+ }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/cursor.rs
b/datafusion/core/src/physical_plan/sorts/cursor.rs
index e52544cf50..8ab2acdda4 100644
--- a/datafusion/core/src/physical_plan/sorts/cursor.rs
+++ b/datafusion/core/src/physical_plan/sorts/cursor.rs
@@ -110,3 +110,24 @@ impl Ord for SortKeyCursor {
}
}
}
+
+/// A cursor into a sorted batch of rows
+pub trait Cursor: Ord {
+ /// Returns true if there are no more rows in this cursor
+ fn is_finished(&self) -> bool;
+
+ /// Advance the cursor, returning the previous row index
+ ///
+ /// Returns `None` if [`Self::is_finished`]
+ fn advance(&mut self) -> Option<usize>;
+}
+
+impl Cursor for SortKeyCursor {
+ fn is_finished(&self) -> bool {
+ self.is_finished()
+ }
+
+ fn advance(&mut self) -> Option<usize> {
+ (!self.is_finished()).then(|| self.advance())
+ }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/merge.rs
b/datafusion/core/src/physical_plan/sorts/merge.rs
new file mode 100644
index 0000000000..d6b43db82b
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/merge.rs
@@ -0,0 +1,258 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::metrics::MemTrackingMetrics;
+use crate::physical_plan::sorts::builder::BatchBuilder;
+use crate::physical_plan::sorts::cursor::Cursor;
+use crate::physical_plan::sorts::stream::{PartitionedStream,
SortKeyCursorStream};
+use crate::physical_plan::{
+ PhysicalSortExpr, RecordBatchStream, SendableRecordBatchStream,
+};
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use futures::Stream;
+use std::pin::Pin;
+use std::task::{ready, Context, Poll};
+
+/// Perform a streaming merge of [`SendableRecordBatchStream`]
+pub(crate) fn streaming_merge(
+ streams: Vec<SendableRecordBatchStream>,
+ schema: SchemaRef,
+ expressions: &[PhysicalSortExpr],
+ tracking_metrics: MemTrackingMetrics,
+ batch_size: usize,
+) -> Result<SendableRecordBatchStream> {
+ let streams = SortKeyCursorStream::try_new(schema.as_ref(), expressions,
streams)?;
+
+ Ok(Box::pin(SortPreservingMergeStream::new(
+ Box::new(streams),
+ schema,
+ tracking_metrics,
+ batch_size,
+ )))
+}
+
+/// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`]
+type CursorStream<C> = Box<dyn PartitionedStream<Output = Result<(C,
RecordBatch)>>>;
+
+#[derive(Debug)]
+struct SortPreservingMergeStream<C> {
+ in_progress: BatchBuilder,
+
+ /// The sorted input streams to merge together
+ streams: CursorStream<C>,
+
+ /// used to record execution metrics
+ tracking_metrics: MemTrackingMetrics,
+
+ /// If the stream has encountered an error
+ aborted: bool,
+
+ /// A loser tree that always produces the minimum cursor
+ ///
+ /// Node 0 stores the top winner, Nodes 1..num_streams store
+ /// the loser nodes
+ ///
+ /// This implements a "Tournament Tree" (aka Loser Tree) to keep
+ /// track of the current smallest element at the top. When the top
+ /// record is taken, the tree structure is not modified, and only
+ /// the path from bottom to top is visited, keeping the number of
+ /// comparisons close to the theoretical limit of `log(S)`.
+ ///
+ /// reference:
<https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
+ loser_tree: Vec<usize>,
+
+ /// If the most recently yielded overall winner has been replaced
+ /// within the loser tree. A value of `false` indicates that the
+ /// overall winner has been yielded but the loser tree has not
+ /// been updated
+ loser_tree_adjusted: bool,
+
+ /// target batch size
+ batch_size: usize,
+
+ /// Vector that holds cursors for each non-exhausted input partition
+ cursors: Vec<Option<C>>,
+}
+
+impl<C: Cursor> SortPreservingMergeStream<C> {
+ fn new(
+ streams: CursorStream<C>,
+ schema: SchemaRef,
+ tracking_metrics: MemTrackingMetrics,
+ batch_size: usize,
+ ) -> Self {
+ let stream_count = streams.partitions();
+
+ Self {
+ in_progress: BatchBuilder::new(schema, stream_count, batch_size),
+ streams,
+ tracking_metrics,
+ aborted: false,
+ cursors: (0..stream_count).map(|_| None).collect(),
+ loser_tree: vec![],
+ loser_tree_adjusted: false,
+ batch_size,
+ }
+ }
+
+ /// If the stream at the given index is not exhausted, and the last cursor
for the
+ /// stream is finished, poll the stream for the next RecordBatch and
create a new
+ /// cursor for the stream from the returned result
+ fn maybe_poll_stream(
+ &mut self,
+ cx: &mut Context<'_>,
+ idx: usize,
+ ) -> Poll<Result<()>> {
+ if self.cursors[idx]
+ .as_ref()
+ .map(|cursor| !cursor.is_finished())
+ .unwrap_or(false)
+ {
+ // Cursor is not finished - don't need a new RecordBatch yet
+ return Poll::Ready(Ok(()));
+ }
+
+ match futures::ready!(self.streams.poll_next(cx, idx)) {
+ None => Poll::Ready(Ok(())),
+ Some(Err(e)) => Poll::Ready(Err(e)),
+ Some(Ok((cursor, batch))) => {
+ self.cursors[idx] = Some(cursor);
+ self.in_progress.push_batch(idx, batch);
+ Poll::Ready(Ok(()))
+ }
+ }
+ }
+
+ fn poll_next_inner(
+ &mut self,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ if self.aborted {
+ return Poll::Ready(None);
+ }
+ // try to initialize the loser tree
+ if self.loser_tree.is_empty() {
+ // Ensure all non-exhausted streams have a cursor from which
+ // rows can be pulled
+ for i in 0..self.streams.partitions() {
+ if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
+ self.aborted = true;
+ return Poll::Ready(Some(Err(e)));
+ }
+ }
+ self.init_loser_tree();
+ }
+
+ // NB timer records time taken on drop, so there are no
+ // calls to `timer.done()` below.
+ let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
+ let _timer = elapsed_compute.timer();
+
+ loop {
+ // Adjust the loser tree if necessary, returning control if needed
+ if !self.loser_tree_adjusted {
+ let winner = self.loser_tree[0];
+ if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
+ self.aborted = true;
+ return Poll::Ready(Some(Err(e)));
+ }
+ self.update_loser_tree();
+ }
+
+ let stream_idx = self.loser_tree[0];
+ let cursor = self.cursors[stream_idx].as_mut();
+ if let Some(row_idx) = cursor.and_then(Cursor::advance) {
+ self.loser_tree_adjusted = false;
+ self.in_progress.push_row(stream_idx, row_idx);
+ if self.in_progress.len() < self.batch_size {
+ continue;
+ }
+ }
+
+ return
Poll::Ready(self.in_progress.build_record_batch().transpose());
+ }
+ }
+
+ /// Returns `true` if the cursor at index `a` is greater than at index `b`
+ #[inline]
+ fn is_gt(&self, a: usize, b: usize) -> bool {
+ match (&self.cursors[a], &self.cursors[b]) {
+ (None, _) => true,
+ (_, None) => false,
+ (Some(a), Some(b)) => b < a,
+ }
+ }
+
+ /// Attempts to initialize the loser tree with one value from each
+ /// non exhausted input, if possible
+ fn init_loser_tree(&mut self) {
+ // Init loser tree
+ self.loser_tree = vec![usize::MAX; self.cursors.len()];
+ for i in 0..self.cursors.len() {
+ let mut winner = i;
+ let mut cmp_node = (self.cursors.len() + i) / 2;
+ while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
+ let challenger = self.loser_tree[cmp_node];
+ if self.is_gt(winner, challenger) {
+ self.loser_tree[cmp_node] = winner;
+ winner = challenger;
+ }
+
+ cmp_node /= 2;
+ }
+ self.loser_tree[cmp_node] = winner;
+ }
+ self.loser_tree_adjusted = true;
+ }
+
+ /// Attempts to update the loser tree, following winner replacement, if
possible
+ fn update_loser_tree(&mut self) {
+ let mut winner = self.loser_tree[0];
+ // Replace overall winner by walking tree of losers
+ let mut cmp_node = (self.cursors.len() + winner) / 2;
+ while cmp_node != 0 {
+ let challenger = self.loser_tree[cmp_node];
+ if self.is_gt(winner, challenger) {
+ self.loser_tree[cmp_node] = winner;
+ winner = challenger;
+ }
+ cmp_node /= 2;
+ }
+ self.loser_tree[0] = winner;
+ self.loser_tree_adjusted = true;
+ }
+}
+
+impl<C: Cursor + Unpin> Stream for SortPreservingMergeStream<C> {
+ type Item = Result<RecordBatch>;
+
+ fn poll_next(
+ mut self: Pin<&mut Self>,
+ cx: &mut Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ let poll = self.poll_next_inner(cx);
+ self.tracking_metrics.record_poll(poll)
+ }
+}
+
+impl<C: Cursor + Unpin> RecordBatchStream for SortPreservingMergeStream<C> {
+ fn schema(&self) -> SchemaRef {
+ self.in_progress.schema().clone()
+ }
+}
diff --git a/datafusion/core/src/physical_plan/sorts/mod.rs
b/datafusion/core/src/physical_plan/sorts/mod.rs
index db6ab5c604..cd5dae27dc 100644
--- a/datafusion/core/src/physical_plan/sorts/mod.rs
+++ b/datafusion/core/src/physical_plan/sorts/mod.rs
@@ -17,30 +17,14 @@
//! Sort functionalities
-use crate::physical_plan::SendableRecordBatchStream;
-use std::fmt::{Debug, Formatter};
-
+mod builder;
mod cursor;
mod index;
+mod merge;
pub mod sort;
pub mod sort_preserving_merge;
+mod stream;
pub use cursor::SortKeyCursor;
pub use index::RowIndex;
-
-pub(crate) struct SortedStream {
- stream: SendableRecordBatchStream,
- mem_used: usize,
-}
-
-impl Debug for SortedStream {
- fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
- write!(f, "InMemSorterStream")
- }
-}
-
-impl SortedStream {
- pub(crate) fn new(stream: SendableRecordBatchStream, mem_used: usize) ->
Self {
- Self { stream, mem_used }
- }
-}
+pub(crate) use merge::streaming_merge;
diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs
b/datafusion/core/src/physical_plan/sorts/sort.rs
index c3fc06206c..1eceb02c2d 100644
--- a/datafusion/core/src/physical_plan/sorts/sort.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort.rs
@@ -30,8 +30,7 @@ use crate::physical_plan::expressions::PhysicalSortExpr;
use crate::physical_plan::metrics::{
BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet,
};
-use
crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
-use crate::physical_plan::sorts::SortedStream;
+use crate::physical_plan::sorts::merge::streaming_merge;
use crate::physical_plan::stream::{RecordBatchReceiverStream,
RecordBatchStreamAdapter};
use crate::physical_plan::{
DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan,
Partitioning,
@@ -169,37 +168,40 @@ impl ExternalSorter {
let batch_size = self.session_config.batch_size();
if self.spilled_before() {
- let tracking_metrics = self
+ let intermediate_metrics = self
.metrics_set
.new_intermediate_tracking(self.partition_id,
&self.runtime.memory_pool);
- let mut streams: Vec<SortedStream> = vec![];
+ let mut merge_metrics = self
+ .metrics_set
+ .new_final_tracking(self.partition_id,
&self.runtime.memory_pool);
+
+ let mut streams = vec![];
if !self.in_mem_batches.is_empty() {
let in_mem_stream = in_mem_partial_sort(
&mut self.in_mem_batches,
self.schema.clone(),
&self.expr,
batch_size,
- tracking_metrics,
+ intermediate_metrics,
self.fetch,
)?;
- let prev_used = self.reservation.free();
- streams.push(SortedStream::new(in_mem_stream, prev_used));
+ // TODO: More accurate, dynamic memory accounting (#5885)
+ merge_metrics.init_mem_used(self.reservation.free());
+ streams.push(in_mem_stream);
}
for spill in self.spills.drain(..) {
let stream = read_spill_as_stream(spill, self.schema.clone())?;
- streams.push(SortedStream::new(stream, 0));
+ streams.push(stream);
}
- let tracking_metrics = self
- .metrics_set
- .new_final_tracking(self.partition_id,
&self.runtime.memory_pool);
- Ok(Box::pin(SortPreservingMergeStream::new_from_streams(
+
+ streaming_merge(
streams,
self.schema.clone(),
&self.expr,
- tracking_metrics,
+ merge_metrics,
self.session_config.batch_size(),
- )?))
+ )
} else if !self.in_mem_batches.is_empty() {
let tracking_metrics = self
.metrics_set
diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
index 14204ef3b4..98cfb24c92 100644
--- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -18,19 +18,9 @@
//! Defines the sort preserving merge plan
use std::any::Any;
-use std::collections::VecDeque;
-use std::pin::Pin;
use std::sync::Arc;
-use std::task::{Context, Poll};
-use arrow::row::{RowConverter, SortField};
-use arrow::{
- array::{make_array as make_arrow_array, MutableArrayData},
- datatypes::SchemaRef,
- record_batch::RecordBatch,
-};
-use futures::stream::{Fuse, FusedStream};
-use futures::{ready, Stream, StreamExt};
+use arrow::datatypes::SchemaRef;
use log::debug;
use tokio::sync::mpsc;
@@ -39,12 +29,11 @@ use crate::execution::context::TaskContext;
use crate::physical_plan::metrics::{
ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
};
-use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream};
+use crate::physical_plan::sorts::streaming_merge;
use crate::physical_plan::stream::RecordBatchReceiverStream;
use crate::physical_plan::{
common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType,
- Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream,
- SendableRecordBatchStream, Statistics,
+ Distribution, ExecutionPlan, Partitioning, SendableRecordBatchStream,
Statistics,
};
use datafusion_physical_expr::{
make_sort_requirements_from_exprs, EquivalenceProperties,
PhysicalSortRequirement,
@@ -206,34 +195,27 @@ impl ExecutionPlan for SortPreservingMergeExec {
context.clone(),
);
- SortedStream::new(
- RecordBatchReceiverStream::create(
- &schema,
- receiver,
- join_handle,
- ),
- 0,
+ RecordBatchReceiverStream::create(
+ &schema,
+ receiver,
+ join_handle,
)
})
.collect(),
Err(_) => (0..input_partitions)
- .map(|partition| {
- let stream =
- self.input.execute(partition,
context.clone())?;
- Ok(SortedStream::new(stream, 0))
- })
+ .map(|partition| self.input.execute(partition,
context.clone()))
.collect::<Result<_>>()?,
};
debug!("Done setting up sender-receiver for
SortPreservingMergeExec::execute");
- let result =
Box::pin(SortPreservingMergeStream::new_from_streams(
+ let result = streaming_merge(
receivers,
schema,
&self.expr,
tracking_metrics,
context.session_config().batch_size(),
- )?);
+ )?;
debug!("Got stream result from
SortPreservingMergeStream::new_from_receivers");
@@ -264,445 +246,6 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}
-struct MergingStreams {
- /// The sorted input streams to merge together
- streams: Vec<Fuse<SendableRecordBatchStream>>,
- /// number of streams
- num_streams: usize,
-}
-
-impl std::fmt::Debug for MergingStreams {
- fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
- f.debug_struct("MergingStreams")
- .field("num_streams", &self.num_streams)
- .finish()
- }
-}
-
-impl MergingStreams {
- fn new(input_streams: Vec<Fuse<SendableRecordBatchStream>>) -> Self {
- Self {
- num_streams: input_streams.len(),
- streams: input_streams,
- }
- }
-
- fn num_streams(&self) -> usize {
- self.num_streams
- }
-}
-
-#[derive(Debug)]
-pub(crate) struct SortPreservingMergeStream {
- /// The schema of the RecordBatches yielded by this stream
- schema: SchemaRef,
-
- /// The sorted input streams to merge together
- streams: MergingStreams,
-
- /// For each input stream maintain a dequeue of RecordBatches
- ///
- /// Exhausted batches will be popped off the front once all
- /// their rows have been yielded to the output
- batches: Vec<VecDeque<RecordBatch>>,
-
- /// The accumulated row indexes for the next record batch
- in_progress: Vec<RowIndex>,
-
- /// The physical expressions to sort by
- column_expressions: Vec<Arc<dyn PhysicalExpr>>,
-
- /// used to record execution metrics
- tracking_metrics: MemTrackingMetrics,
-
- /// If the stream has encountered an error
- aborted: bool,
-
- /// Vector that holds all [`SortKeyCursor`]s
- cursors: Vec<Option<SortKeyCursor>>,
-
- /// A loser tree that always produces the minimum cursor
- ///
- /// Node 0 stores the top winner, Nodes 1..num_streams store
- /// the loser nodes
- ///
- /// This implements a "Tournament Tree" (aka Loser Tree) to keep
- /// track of the current smallest element at the top. When the top
- /// record is taken, the tree structure is not modified, and only
- /// the path from bottom to top is visited, keeping the number of
- /// comparisons close to the theoretical limit of `log(S)`.
- ///
- /// reference:
<https://en.wikipedia.org/wiki/K-way_merge_algorithm#Tournament_Tree>
- loser_tree: Vec<usize>,
-
- /// If the most recently yielded overall winner has been replaced
- /// within the loser tree. A value of `false` indicates that the
- /// overall winner has been yielded but the loser tree has not
- /// been updated
- loser_tree_adjusted: bool,
-
- /// target batch size
- batch_size: usize,
-
- /// row converter
- row_converter: RowConverter,
-}
-
-impl SortPreservingMergeStream {
- pub(crate) fn new_from_streams(
- streams: Vec<SortedStream>,
- schema: SchemaRef,
- expressions: &[PhysicalSortExpr],
- mut tracking_metrics: MemTrackingMetrics,
- batch_size: usize,
- ) -> Result<Self> {
- let stream_count = streams.len();
- let batches = (0..stream_count).map(|_| VecDeque::new()).collect();
- tracking_metrics.init_mem_used(streams.iter().map(|s|
s.mem_used).sum());
- let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect();
-
- let sort_fields = expressions
- .iter()
- .map(|expr| {
- let data_type = expr.expr.data_type(&schema)?;
- Ok(SortField::new_with_options(data_type, expr.options))
- })
- .collect::<Result<Vec<_>>>()?;
- let row_converter = RowConverter::new(sort_fields)?;
-
- Ok(Self {
- schema,
- batches,
- streams: MergingStreams::new(wrappers),
- column_expressions: expressions.iter().map(|x|
x.expr.clone()).collect(),
- tracking_metrics,
- aborted: false,
- in_progress: vec![],
- cursors: (0..stream_count).map(|_| None).collect(),
- loser_tree: Vec::with_capacity(stream_count),
- loser_tree_adjusted: false,
- batch_size,
- row_converter,
- })
- }
-
- /// If the stream at the given index is not exhausted, and the last cursor
for the
- /// stream is finished, poll the stream for the next RecordBatch and
create a new
- /// cursor for the stream from the returned result
- fn maybe_poll_stream(
- &mut self,
- cx: &mut Context<'_>,
- idx: usize,
- ) -> Poll<Result<()>> {
- if self.cursors[idx]
- .as_ref()
- .map(|cursor| !cursor.is_finished())
- .unwrap_or(false)
- {
- // Cursor is not finished - don't need a new RecordBatch yet
- return Poll::Ready(Ok(()));
- }
- let mut empty_batch = false;
- {
- let stream = &mut self.streams.streams[idx];
- if stream.is_terminated() {
- return Poll::Ready(Ok(()));
- }
-
- // Fetch a new input record and create a cursor from it
- match futures::ready!(stream.poll_next_unpin(cx)) {
- None => return Poll::Ready(Ok(())),
- Some(Err(e)) => {
- return Poll::Ready(Err(e));
- }
- Some(Ok(batch)) => {
- if batch.num_rows() > 0 {
- let cols = self
- .column_expressions
- .iter()
- .map(|expr| {
-
Ok(expr.evaluate(&batch)?.into_array(batch.num_rows()))
- })
- .collect::<Result<Vec<_>>>()?;
-
- let rows = match
self.row_converter.convert_columns(&cols) {
- Ok(rows) => rows,
- Err(e) => {
- return
Poll::Ready(Err(DataFusionError::ArrowError(e)));
- }
- };
-
- self.cursors[idx] = Some(SortKeyCursor::new(idx,
rows));
- self.batches[idx].push_back(batch)
- } else {
- empty_batch = true;
- }
- }
- }
- }
-
- if empty_batch {
- self.maybe_poll_stream(cx, idx)
- } else {
- Poll::Ready(Ok(()))
- }
- }
-
- /// Drains the in_progress row indexes, and builds a new RecordBatch from
them
- ///
- /// Will then drop any batches for which all rows have been yielded to the
output
- fn build_record_batch(&mut self) -> Result<RecordBatch> {
- // Mapping from stream index to the index of the first buffer from
that stream
- let mut buffer_idx = 0;
- let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
-
- for batches in &self.batches {
- stream_to_buffer_idx.push(buffer_idx);
- buffer_idx += batches.len();
- }
-
- let columns = self
- .schema
- .fields()
- .iter()
- .enumerate()
- .map(|(column_idx, field)| {
- let arrays = self
- .batches
- .iter()
- .flat_map(|batch| {
- batch.iter().map(|batch|
batch.column(column_idx).data())
- })
- .collect();
-
- let mut array_data = MutableArrayData::new(
- arrays,
- field.is_nullable(),
- self.in_progress.len(),
- );
-
- if self.in_progress.is_empty() {
- return make_arrow_array(array_data.freeze());
- }
-
- let first = &self.in_progress[0];
- let mut buffer_idx =
- stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
- let mut start_row_idx = first.row_idx;
- let mut end_row_idx = start_row_idx + 1;
-
- for row_index in self.in_progress.iter().skip(1) {
- let next_buffer_idx =
- stream_to_buffer_idx[row_index.stream_idx] +
row_index.batch_idx;
-
- if next_buffer_idx == buffer_idx && row_index.row_idx ==
end_row_idx {
- // subsequent row in same batch
- end_row_idx += 1;
- continue;
- }
-
- // emit current batch of rows for current buffer
- array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-
- // start new batch of rows
- buffer_idx = next_buffer_idx;
- start_row_idx = row_index.row_idx;
- end_row_idx = start_row_idx + 1;
- }
-
- // emit final batch of rows
- array_data.extend(buffer_idx, start_row_idx, end_row_idx);
- make_arrow_array(array_data.freeze())
- })
- .collect();
-
- self.in_progress.clear();
-
- // New cursors are only created once the previous cursor for the stream
- // is finished. This means all remaining rows from all but the last
batch
- // for each stream have been yielded to the newly created record batch
- //
- // Additionally as `in_progress` has been drained, there are no longer
- // any RowIndex's reliant on the batch indexes
- //
- // We can therefore drop all but the last batch for each stream
- for batches in &mut self.batches {
- if batches.len() > 1 {
- // Drain all but the last batch
- batches.drain(0..(batches.len() - 1));
- }
- }
-
- RecordBatch::try_new(self.schema.clone(), columns).map_err(Into::into)
- }
-}
-
-impl Stream for SortPreservingMergeStream {
- type Item = Result<RecordBatch>;
-
- fn poll_next(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- let poll = self.poll_next_inner(cx);
- self.tracking_metrics.record_poll(poll)
- }
-}
-
-impl SortPreservingMergeStream {
- #[inline]
- fn poll_next_inner(
- self: &mut Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Option<Result<RecordBatch>>> {
- if self.aborted {
- return Poll::Ready(None);
- }
- // try to initialize the loser tree
- if let Err(e) = ready!(self.init_loser_tree(cx)) {
- return Poll::Ready(Some(Err(e)));
- }
-
- // NB timer records time taken on drop, so there are no
- // calls to `timer.done()` below.
- let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
- let _timer = elapsed_compute.timer();
-
- loop {
- // Adjust the loser tree if necessary, returning control if needed
- if let Err(e) = ready!(self.update_loser_tree(cx)) {
- return Poll::Ready(Some(Err(e)));
- }
-
- let min_cursor_idx = self.loser_tree[0];
- let next = self.cursors[min_cursor_idx]
- .as_mut()
- .filter(|cursor| !cursor.is_finished())
- .map(|cursor| (cursor.stream_idx(), cursor.advance()));
-
- if let Some((stream_idx, row_idx)) = next {
- self.loser_tree_adjusted = false;
- let batch_idx = self.batches[stream_idx].len() - 1;
- self.in_progress.push(RowIndex {
- stream_idx,
- batch_idx,
- row_idx,
- });
- if self.in_progress.len() == self.batch_size {
- return Poll::Ready(Some(self.build_record_batch()));
- }
- } else if !self.in_progress.is_empty() {
- return Poll::Ready(Some(self.build_record_batch()));
- } else {
- return Poll::Ready(None);
- }
- }
- }
-
- /// Attempts to initialize the loser tree with one value from each
- /// non exhausted input, if possible.
- ///
- /// Returns
- /// * Poll::Pending when more data is needed
- /// * Poll::Ready(Ok()) on success
- /// * Poll::Ready(Err..) if any of the inputs errored
- #[inline]
- fn init_loser_tree(
- self: &mut Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Result<()>> {
- let num_streams = self.streams.num_streams();
-
- if !self.loser_tree.is_empty() {
- return Poll::Ready(Ok(()));
- }
-
- // Ensure all non-exhausted streams have a cursor from which
- // rows can be pulled
- for i in 0..num_streams {
- if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
- self.aborted = true;
- return Poll::Ready(Err(e));
- }
- }
-
- // Init loser tree
- self.loser_tree.resize(num_streams, usize::MAX);
- for i in 0..num_streams {
- let mut winner = i;
- let mut cmp_node = (num_streams + i) / 2;
- while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX {
- let challenger = self.loser_tree[cmp_node];
- let challenger_win =
- match (&self.cursors[winner], &self.cursors[challenger]) {
- (None, _) => true,
- (_, None) => false,
- (Some(winner), Some(challenger)) => challenger <
winner,
- };
-
- if challenger_win {
- self.loser_tree[cmp_node] = winner;
- winner = challenger;
- }
-
- cmp_node /= 2;
- }
- self.loser_tree[cmp_node] = winner;
- }
- self.loser_tree_adjusted = true;
- Poll::Ready(Ok(()))
- }
-
- /// Attempts to updated the loser tree, if possible
- ///
- /// Returns
- /// * Poll::Pending when the winning unput was not ready
- /// * Poll::Ready(Ok()) on success
- /// * Poll::Ready(Err..) if any of the winning input erroed
- #[inline]
- fn update_loser_tree(
- self: &mut Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Result<()>> {
- if self.loser_tree_adjusted {
- return Poll::Ready(Ok(()));
- }
-
- let num_streams = self.streams.num_streams();
- let mut winner = self.loser_tree[0];
- if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
- self.aborted = true;
- return Poll::Ready(Err(e));
- }
-
- // Replace overall winner by walking tree of losers
- let mut cmp_node = (num_streams + winner) / 2;
- while cmp_node != 0 {
- let challenger = self.loser_tree[cmp_node];
- let challenger_win = match (&self.cursors[winner],
&self.cursors[challenger])
- {
- (None, _) => true,
- (_, None) => false,
- (Some(winner), Some(challenger)) => challenger < winner,
- };
- if challenger_win {
- self.loser_tree[cmp_node] = winner;
- winner = challenger;
- }
- cmp_node /= 2;
- }
- self.loser_tree[0] = winner;
- self.loser_tree_adjusted = true;
- Poll::Ready(Ok(()))
- }
-}
-
-impl RecordBatchStream for SortPreservingMergeStream {
- fn schema(&self) -> SchemaRef {
- self.schema.clone()
- }
-}
-
#[cfg(test)]
mod tests {
use std::iter::FromIterator;
@@ -710,6 +253,7 @@ mod tests {
use arrow::array::ArrayRef;
use arrow::compute::SortOptions;
use arrow::datatypes::{DataType, Field, Schema};
+ use arrow::record_batch::RecordBatch;
use futures::FutureExt;
use tokio_stream::StreamExt;
@@ -1284,9 +828,10 @@ mod tests {
}
});
- streams.push(SortedStream::new(
- RecordBatchReceiverStream::create(&schema, receiver,
join_handle),
- 0,
+ streams.push(RecordBatchReceiverStream::create(
+ &schema,
+ receiver,
+ join_handle,
));
}
@@ -1294,7 +839,7 @@ mod tests {
let tracking_metrics =
MemTrackingMetrics::new(&metrics, task_ctx.memory_pool(), 0);
- let merge_stream = SortPreservingMergeStream::new_from_streams(
+ let merge_stream = streaming_merge(
streams,
batches.schema(),
sort.as_slice(),
@@ -1303,7 +848,7 @@ mod tests {
)
.unwrap();
- let mut merged =
common::collect(Box::pin(merge_stream)).await.unwrap();
+ let mut merged = common::collect(merge_stream).await.unwrap();
assert_eq!(merged.len(), 1);
let merged = merged.remove(0);
diff --git a/datafusion/core/src/physical_plan/sorts/stream.rs
b/datafusion/core/src/physical_plan/sorts/stream.rs
new file mode 100644
index 0000000000..1bc046042e
--- /dev/null
+++ b/datafusion/core/src/physical_plan/sorts/stream.rs
@@ -0,0 +1,145 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::common::Result;
+use crate::physical_plan::sorts::cursor::SortKeyCursor;
+use crate::physical_plan::SendableRecordBatchStream;
+use crate::physical_plan::{PhysicalExpr, PhysicalSortExpr};
+use arrow::datatypes::Schema;
+use arrow::record_batch::RecordBatch;
+use arrow::row::{RowConverter, SortField};
+use futures::stream::{Fuse, StreamExt};
+use std::sync::Arc;
+use std::task::{ready, Context, Poll};
+
+/// A [`Stream`](futures::Stream) that has multiple partitions that can
+/// be polled separately but not concurrently
+///
+/// Used by sort preserving merge to decouple the cursor merging logic from
+/// the source of the cursors, the intention being to allow preserving
+/// any row encoding performed for intermediate sorts
+pub trait PartitionedStream: std::fmt::Debug + Send {
+ type Output;
+
+ /// Returns the number of partitions
+ fn partitions(&self) -> usize;
+
+ fn poll_next(
+ &mut self,
+ cx: &mut Context<'_>,
+ stream_idx: usize,
+ ) -> Poll<Option<Self::Output>>;
+}
+
+/// A newtype wrapper around a set of fused [`SendableRecordBatchStream`]
+/// that implements debug, and skips over empty [`RecordBatch`]
+struct FusedStreams(Vec<Fuse<SendableRecordBatchStream>>);
+
+impl std::fmt::Debug for FusedStreams {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("FusedStreams")
+ .field("num_streams", &self.0.len())
+ .finish()
+ }
+}
+
+impl FusedStreams {
+ fn poll_next(
+ &mut self,
+ cx: &mut Context<'_>,
+ stream_idx: usize,
+ ) -> Poll<Option<Result<RecordBatch>>> {
+ loop {
+ match ready!(self.0[stream_idx].poll_next_unpin(cx)) {
+ Some(Ok(b)) if b.num_rows() == 0 => continue,
+ r => return Poll::Ready(r),
+ }
+ }
+ }
+}
+
+/// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`]
+/// and computes [`SortKeyCursor`] based on the provided [`PhysicalSortExpr`]
+#[derive(Debug)]
+pub(crate) struct SortKeyCursorStream {
+ /// Converter to convert output of physical expressions
+ converter: RowConverter,
+ /// The physical expressions to sort by
+ column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+ /// Input streams
+ streams: FusedStreams,
+}
+
+impl SortKeyCursorStream {
+ pub(crate) fn try_new(
+ schema: &Schema,
+ expressions: &[PhysicalSortExpr],
+ streams: Vec<SendableRecordBatchStream>,
+ ) -> Result<Self> {
+ let sort_fields = expressions
+ .iter()
+ .map(|expr| {
+ let data_type = expr.expr.data_type(schema)?;
+ Ok(SortField::new_with_options(data_type, expr.options))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let streams = streams.into_iter().map(|s| s.fuse()).collect();
+ let converter = RowConverter::new(sort_fields)?;
+ Ok(Self {
+ converter,
+ column_expressions: expressions.iter().map(|x|
x.expr.clone()).collect(),
+ streams: FusedStreams(streams),
+ })
+ }
+
+ fn convert_batch(
+ &mut self,
+ batch: &RecordBatch,
+ stream_idx: usize,
+ ) -> Result<SortKeyCursor> {
+ let cols = self
+ .column_expressions
+ .iter()
+ .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
+ .collect::<Result<Vec<_>>>()?;
+
+ let rows = self.converter.convert_columns(&cols)?;
+ Ok(SortKeyCursor::new(stream_idx, rows))
+ }
+}
+
+impl PartitionedStream for SortKeyCursorStream {
+ type Output = Result<(SortKeyCursor, RecordBatch)>;
+
+ fn partitions(&self) -> usize {
+ self.streams.0.len()
+ }
+
+ fn poll_next(
+ &mut self,
+ cx: &mut Context<'_>,
+ stream_idx: usize,
+ ) -> Poll<Option<Self::Output>> {
+ Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| {
+ r.and_then(|batch| {
+ let cursor = self.convert_batch(&batch, stream_idx)?;
+ Ok((cursor, batch))
+ })
+ }))
+ }
+}