This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 62edddb  Optimize `SortPreservingMergeStream` to avoid `SortKeyCursor` 
sharing (#1624)
62edddb is described below

commit 62edddb30d84f9593ba4c31a2147e1ae1830ae86
Author: Yijie Shen <[email protected]>
AuthorDate: Sat Jan 22 20:09:08 2022 +0800

    Optimize `SortPreservingMergeStream` to avoid `SortKeyCursor` sharing 
(#1624)
    
    * Change SPMS to use heap sort, use SPMS instead of in-mem-sort as well
    
    * Incorporate metrics, external_sort pass all sort tests
    
    * Remove the original sort, substitute with external sort
    
    * Fix different batch_size setting in SPMS test
    
    * Change to use combine and sort for in memory N-way merge
    
    * Resolve comments on async and doc
    
    * Update sort to avoid deadlock during spilling
    
    * Fix spill hanging
    
    * Optimize SPMS, cursor no more shared
    
    * Add doc for
    
    * Resolve comments
---
 datafusion/src/physical_plan/sorts/mod.rs          |  53 ++++-----
 .../physical_plan/sorts/sort_preserving_merge.rs   | 128 +++++++++++----------
 2 files changed, 90 insertions(+), 91 deletions(-)

diff --git a/datafusion/src/physical_plan/sorts/mod.rs 
b/datafusion/src/physical_plan/sorts/mod.rs
index 1bb880f..b49b583 100644
--- a/datafusion/src/physical_plan/sorts/mod.rs
+++ b/datafusion/src/physical_plan/sorts/mod.rs
@@ -32,7 +32,6 @@ use std::borrow::BorrowMut;
 use std::cmp::Ordering;
 use std::fmt::{Debug, Formatter};
 use std::pin::Pin;
-use std::sync::atomic::AtomicUsize;
 use std::sync::{Arc, RwLock};
 use std::task::{Context, Poll};
 
@@ -51,12 +50,11 @@ pub mod sort_preserving_merge;
 struct SortKeyCursor {
     stream_idx: usize,
     sort_columns: Vec<ArrayRef>,
-    cur_row: AtomicUsize,
+    cur_row: usize,
     num_rows: usize,
 
-    // An index uniquely identifying the record batch scanned by this cursor.
-    batch_idx: usize,
-    batch: Arc<RecordBatch>,
+    // An id uniquely identifying the record batch scanned by this cursor.
+    batch_id: usize,
 
     // A collection of comparators that compare rows in this cursor's batch to
     // the cursors in other batches. Other batches are uniquely identified by
@@ -69,10 +67,9 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         f.debug_struct("SortKeyCursor")
             .field("sort_columns", &self.sort_columns)
-            .field("cur_row", &self.cur_row())
+            .field("cur_row", &self.cur_row)
             .field("num_rows", &self.num_rows)
-            .field("batch_idx", &self.batch_idx)
-            .field("batch", &self.batch)
+            .field("batch_id", &self.batch_id)
             .field("batch_comparators", &"<FUNC>")
             .finish()
     }
@@ -81,39 +78,35 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
 impl SortKeyCursor {
     fn new(
         stream_idx: usize,
-        batch_idx: usize,
-        batch: Arc<RecordBatch>,
+        batch_id: usize,
+        batch: &RecordBatch,
         sort_key: &[Arc<dyn PhysicalExpr>],
         sort_options: Arc<Vec<SortOptions>>,
     ) -> error::Result<Self> {
         let sort_columns = sort_key
             .iter()
-            .map(|expr| 
Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
+            .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())))
             .collect::<error::Result<_>>()?;
         Ok(Self {
             stream_idx,
-            cur_row: AtomicUsize::new(0),
+            cur_row: 0,
             num_rows: batch.num_rows(),
             sort_columns,
-            batch,
-            batch_idx,
+            batch_id,
             batch_comparators: RwLock::new(HashMap::new()),
             sort_options,
         })
     }
 
     fn is_finished(&self) -> bool {
-        self.num_rows == self.cur_row()
+        self.num_rows == self.cur_row
     }
 
-    fn advance(&self) -> usize {
+    fn advance(&mut self) -> usize {
         assert!(!self.is_finished());
-        self.cur_row
-            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
-    }
-
-    fn cur_row(&self) -> usize {
-        self.cur_row.load(std::sync::atomic::Ordering::SeqCst)
+        let t = self.cur_row;
+        self.cur_row += 1;
+        t
     }
 
     /// Compares the sort key pointed to by this instance's row cursor with 
that of another
@@ -143,15 +136,15 @@ impl SortKeyCursor {
 
         self.init_cmp_if_needed(other, &zipped)?;
         let map = self.batch_comparators.read().unwrap();
-        let cmp = map.get(&other.batch_idx).ok_or_else(|| {
+        let cmp = map.get(&other.batch_id).ok_or_else(|| {
             DataFusionError::Execution(format!(
                 "Failed to find comparator for {} cmp {}",
-                self.batch_idx, other.batch_idx
+                self.batch_id, other.batch_id
             ))
         })?;
 
         for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
-            match (l.is_valid(self.cur_row()), r.is_valid(other.cur_row())) {
+            match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
                 (false, true) if sort_options.nulls_first => return 
Ok(Ordering::Less),
                 (false, true) => return Ok(Ordering::Greater),
                 (true, false) if sort_options.nulls_first => {
@@ -159,7 +152,7 @@ impl SortKeyCursor {
                 }
                 (true, false) => return Ok(Ordering::Less),
                 (false, false) => {}
-                (true, true) => match cmp[i](self.cur_row(), other.cur_row()) {
+                (true, true) => match cmp[i](self.cur_row, other.cur_row) {
                     Ordering::Equal => {}
                     o if sort_options.descending => return Ok(o.reverse()),
                     o => return Ok(o),
@@ -178,12 +171,12 @@ impl SortKeyCursor {
         zipped: &[((&ArrayRef, &ArrayRef), &SortOptions)],
     ) -> Result<()> {
         let hm = self.batch_comparators.read().unwrap();
-        if !hm.contains_key(&other.batch_idx) {
+        if !hm.contains_key(&other.batch_id) {
             drop(hm);
             let mut map = self.batch_comparators.write().unwrap();
             let cmp = map
                 .borrow_mut()
-                .entry(other.batch_idx)
+                .entry(other.batch_id)
                 .or_insert_with(|| 
Vec::with_capacity(other.sort_columns.len()));
 
             for (i, ((l, r), _)) in zipped.iter().enumerate() {
@@ -224,8 +217,8 @@ impl PartialOrd for SortKeyCursor {
 struct RowIndex {
     /// The index of the stream
     stream_idx: usize,
-    /// The index of the cursor within the stream's VecDequeue.
-    cursor_idx: usize,
+    /// The index of the batch within the stream's VecDeque.
+    batch_idx: usize,
     /// The row index
     row_idx: usize,
 }
diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs 
b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index 189a9fb..d6a5787 100644
--- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -206,7 +206,9 @@ struct MergingStreams {
     /// ConsumerId
     id: MemoryConsumerId,
     /// The sorted input streams to merge together
-    pub(crate) streams: Mutex<Vec<StreamWrapper>>,
+    streams: Mutex<Vec<StreamWrapper>>,
+    /// number of streams
+    num_streams: usize,
     /// Runtime
     runtime: Arc<RuntimeEnv>,
 }
@@ -220,17 +222,22 @@ impl Debug for MergingStreams {
 }
 
 impl MergingStreams {
-    pub fn new(
+    fn new(
         partition: usize,
         input_streams: Vec<StreamWrapper>,
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
         Self {
             id: MemoryConsumerId::new(partition),
+            num_streams: input_streams.len(),
             streams: Mutex::new(input_streams),
             runtime,
         }
     }
+
+    fn num_streams(&self) -> usize {
+        self.num_streams
+    }
 }
 
 #[async_trait]
@@ -276,11 +283,15 @@ pub(crate) struct SortPreservingMergeStream {
     /// Drop helper for tasks feeding the [`receivers`](Self::receivers)
     _drop_helper: AbortOnDropMany<()>,
 
-    /// For each input stream maintain a dequeue of SortKeyCursor
+    /// For each input stream maintain a deque of RecordBatches
     ///
-    /// Exhausted cursors will be popped off the front once all
+    /// Exhausted batches will be popped off the front once all
     /// their rows have been yielded to the output
-    cursors: Vec<VecDeque<Arc<SortKeyCursor>>>,
+    batches: Vec<VecDeque<RecordBatch>>,
+
+    /// Maintain a flag for each stream denoting if the current cursor
+    /// has finished and needs to poll from the stream
+    cursor_finished: Vec<bool>,
 
     /// The accumulated row indexes for the next record batch
     in_progress: Vec<RowIndex>,
@@ -297,11 +308,11 @@ pub(crate) struct SortPreservingMergeStream {
     /// If the stream has encountered an error
     aborted: bool,
 
-    /// An index to uniquely identify the input stream batch
-    next_batch_index: usize,
+    /// An id to uniquely identify the input stream batch
+    next_batch_id: usize,
 
     /// min heap for record comparison
-    min_heap: BinaryHeap<Arc<SortKeyCursor>>,
+    min_heap: BinaryHeap<SortKeyCursor>,
 
     /// runtime
     runtime: Arc<RuntimeEnv>,
@@ -325,7 +336,7 @@ impl SortPreservingMergeStream {
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
         let stream_count = receivers.len();
-        let cursors = (0..stream_count)
+        let batches = (0..stream_count)
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
@@ -335,7 +346,8 @@ impl SortPreservingMergeStream {
 
         SortPreservingMergeStream {
             schema,
-            cursors,
+            batches,
+            cursor_finished: vec![true; stream_count],
             streams,
             _drop_helper,
             column_expressions: expressions.iter().map(|x| 
x.expr.clone()).collect(),
@@ -343,7 +355,7 @@ impl SortPreservingMergeStream {
             baseline_metrics,
             aborted: false,
             in_progress: vec![],
-            next_batch_index: 0,
+            next_batch_id: 0,
             min_heap: BinaryHeap::with_capacity(stream_count),
             runtime,
         }
@@ -358,7 +370,7 @@ impl SortPreservingMergeStream {
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
         let stream_count = streams.len();
-        let cursors = (0..stream_count)
+        let batches = (0..stream_count)
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
@@ -371,7 +383,8 @@ impl SortPreservingMergeStream {
 
         Self {
             schema,
-            cursors,
+            batches,
+            cursor_finished: vec![true; stream_count],
             streams,
             _drop_helper: AbortOnDropMany(vec![]),
             column_expressions: expressions.iter().map(|x| 
x.expr.clone()).collect(),
@@ -379,7 +392,7 @@ impl SortPreservingMergeStream {
             baseline_metrics,
             aborted: false,
             in_progress: vec![],
-            next_batch_index: 0,
+            next_batch_id: 0,
             min_heap: BinaryHeap::with_capacity(stream_count),
             runtime,
         }
@@ -393,13 +406,10 @@ impl SortPreservingMergeStream {
         cx: &mut Context<'_>,
         idx: usize,
     ) -> Poll<ArrowResult<()>> {
-        if let Some(cursor) = &self.cursors[idx].back() {
-            if !cursor.is_finished() {
-                // Cursor is not finished - don't need a new RecordBatch yet
-                return Poll::Ready(Ok(()));
-            }
+        if !self.cursor_finished[idx] {
+            // Cursor is not finished - don't need a new RecordBatch yet
+            return Poll::Ready(Ok(()));
         }
-
         let mut streams = self.streams.streams.lock().unwrap();
 
         let stream = &mut streams[idx];
@@ -414,25 +424,22 @@ impl SortPreservingMergeStream {
                 return Poll::Ready(Err(e));
             }
             Some(Ok(batch)) => {
-                let cursor = Arc::new(
-                    match SortKeyCursor::new(
-                        idx,
-                        self.next_batch_index, // assign this batch an ID
-                        Arc::new(batch),
-                        &self.column_expressions,
-                        self.sort_options.clone(),
-                    ) {
-                        Ok(cursor) => cursor,
-                        Err(e) => {
-                            return Poll::Ready(Err(ArrowError::ExternalError(
-                                Box::new(e),
-                            )));
-                        }
-                    },
-                );
-                self.next_batch_index += 1;
-                self.min_heap.push(cursor.clone());
-                self.cursors[idx].push_back(cursor)
+                let cursor = match SortKeyCursor::new(
+                    idx,
+                    self.next_batch_id, // assign this batch an ID
+                    &batch,
+                    &self.column_expressions,
+                    self.sort_options.clone(),
+                ) {
+                    Ok(cursor) => cursor,
+                    Err(e) => {
+                        return 
Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
+                    }
+                };
+                self.next_batch_id += 1;
+                self.min_heap.push(cursor);
+                self.cursor_finished[idx] = false;
+                self.batches[idx].push_back(batch)
             }
         }
 
@@ -441,15 +448,15 @@ impl SortPreservingMergeStream {
 
     /// Drains the in_progress row indexes, and builds a new RecordBatch from 
them
     ///
-    /// Will then drop any cursors for which all rows have been yielded to the 
output
+    /// Will then drop any batches for which all rows have been yielded to the 
output
     fn build_record_batch(&mut self) -> ArrowResult<RecordBatch> {
         // Mapping from stream index to the index of the first buffer from 
that stream
         let mut buffer_idx = 0;
-        let mut stream_to_buffer_idx = Vec::with_capacity(self.cursors.len());
+        let mut stream_to_buffer_idx = Vec::with_capacity(self.batches.len());
 
-        for cursors in &self.cursors {
+        for batches in &self.batches {
             stream_to_buffer_idx.push(buffer_idx);
-            buffer_idx += cursors.len();
+            buffer_idx += batches.len();
         }
 
         let columns = self
@@ -459,12 +466,10 @@ impl SortPreservingMergeStream {
             .enumerate()
             .map(|(column_idx, field)| {
                 let arrays = self
-                    .cursors
+                    .batches
                     .iter()
-                    .flat_map(|cursor| {
-                        cursor
-                            .iter()
-                            .map(|cursor| 
cursor.batch.column(column_idx).data())
+                    .flat_map(|batch| {
+                        batch.iter().map(|batch| 
batch.column(column_idx).data())
                     })
                     .collect();
 
@@ -480,13 +485,13 @@ impl SortPreservingMergeStream {
 
                 let first = &self.in_progress[0];
                 let mut buffer_idx =
-                    stream_to_buffer_idx[first.stream_idx] + first.cursor_idx;
+                    stream_to_buffer_idx[first.stream_idx] + first.batch_idx;
                 let mut start_row_idx = first.row_idx;
                 let mut end_row_idx = start_row_idx + 1;
 
                 for row_index in self.in_progress.iter().skip(1) {
                     let next_buffer_idx =
-                        stream_to_buffer_idx[row_index.stream_idx] + 
row_index.cursor_idx;
+                        stream_to_buffer_idx[row_index.stream_idx] + 
row_index.batch_idx;
 
                     if next_buffer_idx == buffer_idx && row_index.row_idx == 
end_row_idx {
                         // subsequent row in same batch
@@ -512,17 +517,17 @@ impl SortPreservingMergeStream {
         self.in_progress.clear();
 
         // New cursors are only created once the previous cursor for the stream
-        // is finished. This means all remaining rows from all but the last 
cursor
+        // is finished. This means all remaining rows from all but the last 
batch
         // for each stream have been yielded to the newly created record batch
         //
         // Additionally as `in_progress` has been drained, there are no longer
-        // any RowIndex's reliant on the cursor indexes
+        // any RowIndex's reliant on the batch indexes
         //
-        // We can therefore drop all but the last cursor for each stream
-        for cursors in &mut self.cursors {
-            if cursors.len() > 1 {
-                // Drain all but the last cursor
-                cursors.drain(0..(cursors.len() - 1));
+        // We can therefore drop all but the last batch for each stream
+        for batches in &mut self.batches {
+            if batches.len() > 1 {
+                // Drain all but the last batch
+                batches.drain(0..(batches.len() - 1));
             }
         }
 
@@ -554,7 +559,7 @@ impl SortPreservingMergeStream {
 
         // Ensure all non-exhausted streams have a cursor from which
         // rows can be pulled
-        for i in 0..self.cursors.len() {
+        for i in 0..self.streams.num_streams() {
             match futures::ready!(self.maybe_poll_stream(cx, i)) {
                 Ok(_) => {}
                 Err(e) => {
@@ -571,9 +576,9 @@ impl SortPreservingMergeStream {
             let _timer = elapsed_compute.timer();
 
             match self.min_heap.pop() {
-                Some(cursor) => {
+                Some(mut cursor) => {
                     let stream_idx = cursor.stream_idx;
-                    let cursor_idx = self.cursors[stream_idx].len() - 1;
+                    let batch_idx = self.batches[stream_idx].len() - 1;
                     let row_idx = cursor.advance();
 
                     let mut cursor_finished = false;
@@ -582,11 +587,12 @@ impl SortPreservingMergeStream {
                         self.min_heap.push(cursor);
                     } else {
                         cursor_finished = true;
+                        self.cursor_finished[stream_idx] = true;
                     }
 
                     self.in_progress.push(RowIndex {
                         stream_idx,
-                        cursor_idx,
+                        batch_idx,
                         row_idx,
                     });
 

Reply via email to