This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 62edddb Optimize `SortPreservingMergeStream` to avoid `SortKeyCursor`
sharing (#1624)
62edddb is described below
commit 62edddb30d84f9593ba4c31a2147e1ae1830ae86
Author: Yijie Shen <[email protected]>
AuthorDate: Sat Jan 22 20:09:08 2022 +0800
Optimize `SortPreservingMergeStream` to avoid `SortKeyCursor` sharing
(#1624)
* Change SPMS to use heap sort, use SPMS instead of in-mem-sort as well
* Incorporate metrics, external_sort pass all sort tests
* Remove the original sort, substitute with external sort
* Fix different batch_size setting in SPMS test
* Change to use combine and sort for in memory N-way merge
* Resolve comments on async and doc
* Update sort to avoid deadlock during spilling
* Fix spill hanging
* Optimize SPMS, cursor no more shared
* Add doc for
* Resolve comments
---
datafusion/src/physical_plan/sorts/mod.rs | 53 ++++-----
.../physical_plan/sorts/sort_preserving_merge.rs | 128 +++++++++++----------
2 files changed, 90 insertions(+), 91 deletions(-)
diff --git a/datafusion/src/physical_plan/sorts/mod.rs
b/datafusion/src/physical_plan/sorts/mod.rs
index 1bb880f..b49b583 100644
--- a/datafusion/src/physical_plan/sorts/mod.rs
+++ b/datafusion/src/physical_plan/sorts/mod.rs
@@ -32,7 +32,6 @@ use std::borrow::BorrowMut;
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
-use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, RwLock};
use std::task::{Context, Poll};
@@ -51,12 +50,11 @@ pub mod sort_preserving_merge;
struct SortKeyCursor {
stream_idx: usize,
sort_columns: Vec<ArrayRef>,
- cur_row: AtomicUsize,
+ cur_row: usize,
num_rows: usize,
- // An index uniquely identifying the record batch scanned by this cursor.
- batch_idx: usize,
- batch: Arc<RecordBatch>,
+ // An id uniquely identifying the record batch scanned by this cursor.
+ batch_id: usize,
// A collection of comparators that compare rows in this cursor's batch to
// the cursors in other batches. Other batches are uniquely identified by
@@ -69,10 +67,9 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("SortKeyCursor")
.field("sort_columns", &self.sort_columns)
- .field("cur_row", &self.cur_row())
+ .field("cur_row", &self.cur_row)
.field("num_rows", &self.num_rows)
- .field("batch_idx", &self.batch_idx)
- .field("batch", &self.batch)
+ .field("batch_id", &self.batch_id)
.field("batch_comparators", &"<FUNC>")
.finish()
}
@@ -81,39 +78,35 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
impl SortKeyCursor {
fn new(
stream_idx: usize,
- batch_idx: usize,
- batch: Arc<RecordBatch>,
+ batch_id: usize,
+ batch: &RecordBatch,
sort_key: &[Arc<dyn PhysicalExpr>],
sort_options: Arc<Vec<SortOptions>>,
) -> error::Result<Self> {
let sort_columns = sort_key
.iter()
- .map(|expr|
Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
+ .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
.collect::<error::Result<_>>()?;
Ok(Self {
stream_idx,
- cur_row: AtomicUsize::new(0),
+ cur_row: 0,
num_rows: batch.num_rows(),
sort_columns,
- batch,
- batch_idx,
+ batch_id,
batch_comparators: RwLock::new(HashMap::new()),
sort_options,
})
}
fn is_finished(&self) -> bool {
- self.num_rows == self.cur_row()
+ self.num_rows == self.cur_row
}
- fn advance(&self) -> usize {
+ fn advance(&mut self) -> usize {
assert!(!self.is_finished());
- self.cur_row
- .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
- }
-
- fn cur_row(&self) -> usize {
- self.cur_row.load(std::sync::atomic::Ordering::SeqCst)
+ let t = self.cur_row;
+ self.cur_row += 1;
+ t
}
/// Compares the sort key pointed to by this instance's row cursor with
that of another
@@ -143,15 +136,15 @@ impl SortKeyCursor {
self.init_cmp_if_needed(other, &zipped)?;
let map = self.batch_comparators.read().unwrap();
- let cmp = map.get(&other.batch_idx).ok_or_else(|| {
+ let cmp = map.get(&other.batch_id).ok_or_else(|| {
DataFusionError::Execution(format!(
"Failed to find comparator for {} cmp {}",
- self.batch_idx, other.batch_idx
+ self.batch_id, other.batch_id
))
})?;
for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
- match (l.is_valid(self.cur_row()), r.is_valid(other.cur_row())) {
+ match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
(false, true) if sort_options.nulls_first => return
Ok(Ordering::Less),
(false, true) => return Ok(Ordering::Greater),
(true, false) if sort_options.nulls_first => {
@@ -159,7 +152,7 @@ impl SortKeyCursor {
}
(true, false) => return Ok(Ordering::Less),
(false, false) => {}
- (true, true) => match cmp[i](self.cur_row(), other.cur_row()) {
+ (true, true) => match cmp[i](self.cur_row, other.cur_row) {
Ordering::Equal => {}
o if sort_options.descending => return Ok(o.reverse()),
o => return Ok(o),
@@ -178,12 +171,12 @@ impl SortKeyCursor {
zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)],
) -> Result<()> {
let hm = self.batch_comparators.read().unwrap();
- if !hm.contains_key(&other.batch_idx) {
+ if !hm.contains_key(&other.batch_id) {
drop(hm);
let mut map = self.batch_comparators.write().unwrap();
let cmp = map
.borrow_mut()
- .entry(other.batch_idx)
+ .entry(other.batch_id)
.or_insert_with(||
Vec::with_capacity(other.sort_columns.len()));
for (i, ((l, r), _)) in zipped.iter().enumerate() {
@@ -224,8 +217,8 @@ impl PartialOrd for SortKeyCursor {
struct RowIndex {
/// The index of the stream
stream_idx: usize,
- /// The index of the cursor within the stream's VecDequeue.
- cursor_idx: usize,
+    /// The index of the batch within the stream's VecDeque.
+ batch_idx: usize,
/// The row index
row_idx: usize,
}
diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index 189a9fb..d6a5787 100644
--- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -206,7 +206,9 @@ struct MergingStreams {
/// ConsumerId
id: MemoryConsumerId,
/// The sorted input streams to merge together
- pub(crate) streams: Mutex<Vec<StreamWrapper>>,
+ streams: Mutex<Vec<StreamWrapper>>,
+ /// number of streams
+ num_streams: usize,
/// Runtime
runtime: Arc<RuntimeEnv>,
}
@@ -220,17 +222,22 @@ impl Debug for MergingStreams {
}
impl MergingStreams {
- pub fn new(
+ fn new(
partition: usize,
input_streams: Vec<StreamWrapper>,
runtime: Arc<RuntimeEnv>,
) -> Self {
Self {
id: MemoryConsumerId::new(partition),
+ num_streams: input_streams.len(),
streams: Mutex::new(input_streams),
runtime,
}
}
+
+ fn num_streams(&self) -> usize {
+ self.num_streams
+ }
}
#[async_trait]
@@ -276,11 +283,15 @@ pub(crate) struct SortPreservingMergeStream {
/// Drop helper for tasks feeding the [`receivers`](Self::receivers)
_drop_helper: AbortOnDropMany<()>,
- /// For each input stream maintain a dequeue of SortKeyCursor
+    /// For each input stream maintain a deque of RecordBatches
///
- /// Exhausted cursors will be popped off the front once all
+ /// Exhausted batches will be popped off the front once all
/// their rows have been yielded to the output
- cursors: Vec<VecDeque<Arc<SortKeyCursor>>>,
+ batches: Vec<VecDeque<RecordBatch>>,
+
+ /// Maintain a flag for each stream denoting if the current cursor
+ /// has finished and needs to poll from the stream
+ cursor_finished: Vec<bool>,
/// The accumulated row indexes for the next record batch
in_progress: Vec<RowIndex>,
@@ -297,11 +308,11 @@ pub(crate) struct SortPreservingMergeStream {
/// If the stream has encountered an error
aborted: bool,
- /// An index to uniquely identify the input stream batch
- next_batch_index: usize,
+ /// An id to uniquely identify the input stream batch
+ next_batch_id: usize,
/// min heap for record comparison
- min_heap: BinaryHeap<Arc<SortKeyCursor>>,
+ min_heap: BinaryHeap<SortKeyCursor>,
/// runtime
runtime: Arc<RuntimeEnv>,
@@ -325,7 +336,7 @@ impl SortPreservingMergeStream {
runtime: Arc<RuntimeEnv>,
) -> Self {
let stream_count = receivers.len();
- let cursors = (0..stream_count)
+ let batches = (0..stream_count)
.into_iter()
.map(|_| VecDeque::new())
.collect();
@@ -335,7 +346,8 @@ impl SortPreservingMergeStream {
SortPreservingMergeStream {
schema,
- cursors,
+ batches,
+ cursor_finished: vec![true; stream_count],
streams,
_drop_helper,
column_expressions: expressions.iter().map(|x|
x.expr.clone()).collect(),
@@ -343,7 +355,7 @@ impl SortPreservingMergeStream {
baseline_metrics,
aborted: false,
in_progress: vec![],
- next_batch_index: 0,
+ next_batch_id: 0,
min_heap: BinaryHeap::with_capacity(stream_count),
runtime,
}
@@ -358,7 +370,7 @@ impl SortPreservingMergeStream {
runtime: Arc<RuntimeEnv>,
) -> Self {
let stream_count = streams.len();
- let cursors = (0..stream_count)
+ let batches = (0..stream_count)
.into_iter()
.map(|_| VecDeque::new())
.collect();
@@ -371,7 +383,8 @@ impl SortPreservingMergeStream {
Self {
schema,
- cursors,
+ batches,
+ cursor_finished: vec![true; stream_count],
streams,
_drop_helper: AbortOnDropMany(vec![]),
column_expressions: expressions.iter().map(|x|
x.expr.clone()).collect(),
@@ -379,7 +392,7 @@ impl SortPreservingMergeStream {
baseline_metrics,
aborted: false,
in_progress: vec![],
- next_batch_index: 0,
+ next_batch_id: 0,
min_heap: BinaryHeap::with_capacity(stream_count),
runtime,
}
@@ -393,13 +406,10 @@ impl SortPreservingMergeStream {
cx: &mut Context<'_>,
idx: usize,
) -> Poll<ArrowResult<()>> {
- if let Some(cursor) = &self.cursors[idx].back() {
- if !cursor.is_finished() {
- // Cursor is not finished - don't need a new RecordBatch yet
- return Poll::Ready(Ok(()));
- }
+ if !self.cursor_finished[idx] {
+ // Cursor is not finished - don't need a new RecordBatch yet
+ return Poll::Ready(Ok(()));
}
-
let mut streams = self.streams.streams.lock().unwrap();
let stream = &mut streams[idx];
@@ -414,25 +424,22 @@ impl SortPreservingMergeStream {
return Poll::Ready(Err(e));
}
Some(Ok(batch)) => {
- let cursor = Arc::new(
- match SortKeyCursor::new(
- idx,
- self.next_batch_index, // assign this batch an ID
- Arc::new(batch),
- &self.column_expressions,
- self.sort_options.clone(),
- ) {
- Ok(cursor) => cursor,
- Err(e) => {
- return Poll::Ready(Err(ArrowError::ExternalError(
- Box::new(e),
- )));
- }
- },
- );
- self.next_batch_index += 1;
- self.min_heap.push(cursor.clone());
- self.cursors[idx].push_back(cursor)
+ let cursor = match SortKeyCursor::new(
+ idx,
+ self.next_batch_id, // assign this batch an ID
+ &batch,
+ &self.column_expressions,
+ self.sort_options.clone(),
+ ) {
+ Ok(cursor) => cursor,
+ Err(e) => {
+ return
Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
+ }
+ };
+ self.next_batch_id += 1;
+ self.min_heap.push(cursor);
+ self.cursor_finished[idx] = false;
+ self.batches[idx].push_back(batch)
}
}
@@ -441,15 +448,15 @@ impl SortPreservingMergeStream {
/// Drains the in_progress row indexes, and builds a new RecordBatch from
them
///
- /// Will then drop any cursors for which all rows have been yielded to the
output
+ /// Will then drop any batches for which all rows have been yielded to the
output
fn build_record_batch(&mut self) -> ArrowResult<RecordBatch> {
// Mapping from stream index to the index of the first buffer from
that stream
let mut buffer_idx = 0;
- let mut stream_to_buffer_idx = Vec::with_capacity(self.cursors.len());
+ let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
- for cursors in &self.cursors {
+ for batches in &self.batches {
stream_to_buffer_idx.push(buffer_idx);
- buffer_idx += cursors.len();
+ buffer_idx += batches.len();
}
let columns = self
@@ -459,12 +466,10 @@ impl SortPreservingMergeStream {
.enumerate()
.map(|(column_idx, field)| {
let arrays = self
- .cursors
+ .batches
.iter()
- .flat_map(|cursor| {
- cursor
- .iter()
- .map(|cursor|
cursor.batch.column(column_idx).data())
+ .flat_map(|batch| {
+ batch.iter().map(|batch|
batch.column(column_idx).data())
})
.collect();
@@ -480,13 +485,13 @@ impl SortPreservingMergeStream {
let first = &self.in_progress[0];
let mut buffer_idx =
- stream_to_buffer_idx[first.stream_idx] + first.cursor_idx;
+ stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
let mut start_row_idx = first.row_idx;
let mut end_row_idx = start_row_idx + 1;
for row_index in self.in_progress.iter().skip(1) {
let next_buffer_idx =
- stream_to_buffer_idx[row_index.stream_idx] +
row_index.cursor_idx;
+ stream_to_buffer_idx[row_index.stream_idx] +
row_index.batch_idx;
if next_buffer_idx == buffer_idx && row_index.row_idx ==
end_row_idx {
// subsequent row in same batch
@@ -512,17 +517,17 @@ impl SortPreservingMergeStream {
self.in_progress.clear();
// New cursors are only created once the previous cursor for the stream
- // is finished. This means all remaining rows from all but the last
cursor
+ // is finished. This means all remaining rows from all but the last
batch
// for each stream have been yielded to the newly created record batch
//
// Additionally as `in_progress` has been drained, there are no longer
- // any RowIndex's reliant on the cursor indexes
+ // any RowIndex's reliant on the batch indexes
//
- // We can therefore drop all but the last cursor for each stream
- for cursors in &mut self.cursors {
- if cursors.len() > 1 {
- // Drain all but the last cursor
- cursors.drain(0..(cursors.len() - 1));
+ // We can therefore drop all but the last batch for each stream
+ for batches in &mut self.batches {
+ if batches.len() > 1 {
+ // Drain all but the last batch
+ batches.drain(0..(batches.len() - 1));
}
}
@@ -554,7 +559,7 @@ impl SortPreservingMergeStream {
// Ensure all non-exhausted streams have a cursor from which
// rows can be pulled
- for i in 0..self.cursors.len() {
+ for i in 0..self.streams.num_streams() {
match futures::ready!(self.maybe_poll_stream(cx, i)) {
Ok(_) => {}
Err(e) => {
@@ -571,9 +576,9 @@ impl SortPreservingMergeStream {
let _timer = elapsed_compute.timer();
match self.min_heap.pop() {
- Some(cursor) => {
+ Some(mut cursor) => {
let stream_idx = cursor.stream_idx;
- let cursor_idx = self.cursors[stream_idx].len() - 1;
+ let batch_idx = self.batches[stream_idx].len() - 1;
let row_idx = cursor.advance();
let mut cursor_finished = false;
@@ -582,11 +587,12 @@ impl SortPreservingMergeStream {
self.min_heap.push(cursor);
} else {
cursor_finished = true;
+ self.cursor_finished[stream_idx] = true;
}
self.in_progress.push(RowIndex {
stream_idx,
- cursor_idx,
+ batch_idx,
row_idx,
});