This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7d819d1  Consolidate sort and external_sort (#1596)
7d819d1 is described below

commit 7d819d1f5d24cc99d1ab30001f97fe15a8a7a802
Author: Yijie Shen <[email protected]>
AuthorDate: Fri Jan 21 22:45:33 2022 +0800

    Consolidate sort and external_sort (#1596)
    
    * Change SPMS to use heap sort, use SPMS instead of in-mem-sort as well
    
    * Incorporate metrics, external_sort pass all sort tests
    
    * Remove the original sort, substitute with external sort
    
    * Fix different batch_size setting in SPMS test
    
    * Change to use combine and sort for in memory N-way merge
    
    * Resolve comments on async and doc
    
    * Update sort to avoid deadlock during spilling
    
    * Fix spill hanging
---
 datafusion/src/execution/context.rs                |   1 +
 datafusion/src/physical_plan/common.rs             |  14 +-
 datafusion/src/physical_plan/explain.rs            |   5 +
 .../src/physical_plan/sorts/external_sort.rs       | 657 ---------------------
 datafusion/src/physical_plan/sorts/in_mem_sort.rs  | 241 --------
 datafusion/src/physical_plan/sorts/mod.rs          |  65 +-
 datafusion/src/physical_plan/sorts/sort.rs         | 606 +++++++++++++++----
 .../physical_plan/sorts/sort_preserving_merge.rs   | 182 +++---
 datafusion/tests/provider_filter_pushdown.rs       |   6 +-
 datafusion/tests/sql/joins.rs                      |  32 +-
 10 files changed, 641 insertions(+), 1168 deletions(-)

diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index a3ca29f..ceea83d 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -1211,6 +1211,7 @@ impl FunctionRegistry for ExecutionContextState {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::execution::context::QueryPlanner;
     use crate::from_slice::FromSlice;
     use crate::logical_plan::plan::Projection;
     use crate::logical_plan::TableScan;
diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs
index b199e63..cabb13a 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -20,6 +20,7 @@
 use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::error::{DataFusionError, Result};
 use crate::execution::runtime_env::RuntimeEnv;
+use crate::physical_plan::metrics::BaselineMetrics;
 use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics};
 use arrow::compute::concat;
 use arrow::datatypes::{Schema, SchemaRef};
@@ -41,15 +42,21 @@ pub struct SizedRecordBatchStream {
     schema: SchemaRef,
     batches: Vec<Arc<RecordBatch>>,
     index: usize,
+    baseline_metrics: BaselineMetrics,
 }
 
 impl SizedRecordBatchStream {
     /// Create a new RecordBatchIterator
-    pub fn new(schema: SchemaRef, batches: Vec<Arc<RecordBatch>>) -> Self {
+    pub fn new(
+        schema: SchemaRef,
+        batches: Vec<Arc<RecordBatch>>,
+        baseline_metrics: BaselineMetrics,
+    ) -> Self {
         SizedRecordBatchStream {
             schema,
             index: 0,
             batches,
+            baseline_metrics,
         }
     }
 }
@@ -61,12 +68,13 @@ impl Stream for SizedRecordBatchStream {
         mut self: std::pin::Pin<&mut Self>,
         _: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        Poll::Ready(if self.index < self.batches.len() {
+        let poll = Poll::Ready(if self.index < self.batches.len() {
             self.index += 1;
             Some(Ok(self.batches[self.index - 1].as_ref().clone()))
         } else {
             None
-        })
+        });
+        self.baseline_metrics.record_poll(poll)
     }
 }
 
diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs
index df3dc98..f827dc3 100644
--- a/datafusion/src/physical_plan/explain.rs
+++ b/datafusion/src/physical_plan/explain.rs
@@ -32,6 +32,7 @@ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatc
 
 use super::SendableRecordBatchStream;
 use crate::execution::runtime_env::RuntimeEnv;
+use crate::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
 use async_trait::async_trait;
 
 /// Explain execution plan operator. This operator contains the string
@@ -146,9 +147,13 @@ impl ExecutionPlan for ExplainExec {
             ],
         )?;
 
+        let metrics = ExecutionPlanMetricsSet::new();
+        let baseline_metrics = BaselineMetrics::new(&metrics, partition);
+
         Ok(Box::pin(SizedRecordBatchStream::new(
             self.schema.clone(),
             vec![Arc::new(record_batch)],
+            baseline_metrics,
         )))
     }
 
diff --git a/datafusion/src/physical_plan/sorts/external_sort.rs b/datafusion/src/physical_plan/sorts/external_sort.rs
deleted file mode 100644
index 6c60aac..0000000
--- a/datafusion/src/physical_plan/sorts/external_sort.rs
+++ /dev/null
@@ -1,657 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Defines the External-Sort plan
-
-use crate::error::{DataFusionError, Result};
-use crate::execution::memory_manager::{
-    ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager,
-};
-use crate::execution::runtime_env::RuntimeEnv;
-use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream};
-use crate::physical_plan::expressions::PhysicalSortExpr;
-use crate::physical_plan::metrics::{
-    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
-};
-use crate::physical_plan::sorts::in_mem_sort::InMemSortStream;
-use crate::physical_plan::sorts::sort::sort_batch;
-use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
-use crate::physical_plan::sorts::SortedStream;
-use crate::physical_plan::stream::RecordBatchReceiverStream;
-use crate::physical_plan::{
-    DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
-    SendableRecordBatchStream, Statistics,
-};
-use arrow::datatypes::SchemaRef;
-use arrow::error::Result as ArrowResult;
-use arrow::ipc::reader::FileReader;
-use arrow::record_batch::RecordBatch;
-use async_trait::async_trait;
-use futures::lock::Mutex;
-use futures::StreamExt;
-use log::{error, info};
-use std::any::Any;
-use std::fmt;
-use std::fmt::{Debug, Formatter};
-use std::fs::File;
-use std::io::BufReader;
-use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::Arc;
-use tokio::sync::mpsc::{Receiver as TKReceiver, Sender as TKSender};
-use tokio::task;
-
-/// Sort arbitrary size of data to get an total order (may spill several times during sorting based on free memory available).
-///
-/// The basic architecture of the algorithm:
-///
-/// let spills = vec![];
-/// let in_mem_batches = vec![];
-/// while (input.has_next()) {
-///     let batch = input.next();
-///     // no enough memory available, spill first.
-///     if exec_memory_available < size_of(batch) {
-///         let ordered_stream = in_mem_heap_sort(in_mem_batches.drain(..));
-///         let tmp_file = spill_write(ordered_stream);
-///         spills.push(tmp_file);
-///     }
-///     // sort the batch while it's probably still in cache and buffer it.
-///     let sorted = sort_by_key(batch);
-///     in_mem_batches.push(sorted);
-/// }
-///
-/// let partial_ordered_streams = vec![];
-/// let in_mem_stream = in_mem_heap_sort(in_mem_batches.drain(..));
-/// partial_ordered_streams.push(in_mem_stream);
-/// partial_ordered_streams.extend(spills.drain(..).map(read_as_stream));
-/// let result = sort_preserving_merge(partial_ordered_streams);
-struct ExternalSorter {
-    id: MemoryConsumerId,
-    schema: SchemaRef,
-    in_mem_batches: Mutex<Vec<RecordBatch>>,
-    spills: Mutex<Vec<String>>,
-    /// Sort expressions
-    expr: Vec<PhysicalSortExpr>,
-    runtime: Arc<RuntimeEnv>,
-    metrics: ExecutionPlanMetricsSet,
-    used: AtomicUsize,
-    spilled_bytes: AtomicUsize,
-    spilled_count: AtomicUsize,
-}
-
-impl ExternalSorter {
-    pub fn new(
-        partition_id: usize,
-        schema: SchemaRef,
-        expr: Vec<PhysicalSortExpr>,
-        runtime: Arc<RuntimeEnv>,
-    ) -> Self {
-        Self {
-            id: MemoryConsumerId::new(partition_id),
-            schema,
-            in_mem_batches: Mutex::new(vec![]),
-            spills: Mutex::new(vec![]),
-            expr,
-            runtime,
-            metrics: ExecutionPlanMetricsSet::new(),
-            used: AtomicUsize::new(0),
-            spilled_bytes: AtomicUsize::new(0),
-            spilled_count: AtomicUsize::new(0),
-        }
-    }
-
-    async fn insert_batch(&self, input: RecordBatch) -> Result<()> {
-        let size = batch_byte_size(&input);
-        self.try_grow(size).await?;
-        self.used.fetch_add(size, Ordering::SeqCst);
-        // sort each batch as it's inserted, more probably to be cache-resident
-        let sorted_batch = sort_batch(input, self.schema.clone(), &*self.expr)?;
-        let mut in_mem_batches = self.in_mem_batches.lock().await;
-        in_mem_batches.push(sorted_batch);
-        Ok(())
-    }
-
-    /// MergeSort in mem batches as well as spills into total order with `SortPreservingMergeStream`.
-    async fn sort(&self) -> Result<SendableRecordBatchStream> {
-        let partition = self.partition_id();
-        let mut in_mem_batches = self.in_mem_batches.lock().await;
-        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-        let mut streams: Vec<SortedStream> = vec![];
-        let in_mem_stream = in_mem_partial_sort(
-            &mut *in_mem_batches,
-            self.schema.clone(),
-            &self.expr,
-            self.runtime.batch_size(),
-            baseline_metrics,
-        )
-        .await?;
-        streams.push(SortedStream::new(in_mem_stream, self.used()));
-
-        let mut spills = self.spills.lock().await;
-
-        for spill in spills.drain(..) {
-            let stream = read_spill_as_stream(spill, self.schema.clone()).await?;
-            streams.push(SortedStream::new(stream, 0));
-        }
-        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-
-        Ok(Box::pin(
-            SortPreservingMergeStream::new_from_stream(
-                streams,
-                self.schema.clone(),
-                &self.expr,
-                baseline_metrics,
-                partition,
-                self.runtime.clone(),
-            )
-            .await,
-        ))
-    }
-
-    fn used(&self) -> usize {
-        self.used.load(Ordering::SeqCst)
-    }
-
-    fn spilled_bytes(&self) -> usize {
-        self.spilled_bytes.load(Ordering::SeqCst)
-    }
-
-    fn spilled_count(&self) -> usize {
-        self.spilled_count.load(Ordering::SeqCst)
-    }
-}
-
-impl Debug for ExternalSorter {
-    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        f.debug_struct("ExternalSorter")
-            .field("id", &self.id())
-            .field("memory_used", &self.used())
-            .field("spilled_bytes", &self.spilled_bytes())
-            .field("spilled_count", &self.spilled_count())
-            .finish()
-    }
-}
-
-#[async_trait]
-impl MemoryConsumer for ExternalSorter {
-    fn name(&self) -> String {
-        "ExternalSorter".to_owned()
-    }
-
-    fn id(&self) -> &MemoryConsumerId {
-        &self.id
-    }
-
-    fn memory_manager(&self) -> Arc<MemoryManager> {
-        self.runtime.memory_manager.clone()
-    }
-
-    fn type_(&self) -> &ConsumerType {
-        &ConsumerType::Requesting
-    }
-
-    async fn spill(&self) -> Result<usize> {
-        info!(
-            "{}[{}] spilling sort data of {} to disk while inserting ({} time(s) so far)",
-            self.name(),
-            self.id(),
-            self.used(),
-            self.spilled_count()
-        );
-
-        let partition = self.partition_id();
-        let mut in_mem_batches = self.in_mem_batches.lock().await;
-        // we could always get a chance to free some memory as long as we are holding some
-        if in_mem_batches.len() == 0 {
-            return Ok(0);
-        }
-
-        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-
-        let path = self.runtime.disk_manager.create_tmp_file()?;
-        let stream = in_mem_partial_sort(
-            &mut *in_mem_batches,
-            self.schema.clone(),
-            &*self.expr,
-            self.runtime.batch_size(),
-            baseline_metrics,
-        )
-        .await;
-
-        let total_size =
-            spill_partial_sorted_stream(&mut stream?, path.clone(), self.schema.clone())
-                .await?;
-
-        let mut spills = self.spills.lock().await;
-        let used = self.used.swap(0, Ordering::SeqCst);
-        self.spilled_count.fetch_add(1, Ordering::SeqCst);
-        self.spilled_bytes.fetch_add(total_size, Ordering::SeqCst);
-        spills.push(path);
-        Ok(used)
-    }
-
-    fn mem_used(&self) -> usize {
-        self.used.load(Ordering::SeqCst)
-    }
-}
-
-/// consume the `sorted_bathes` and do in_mem_sort
-async fn in_mem_partial_sort(
-    sorted_bathes: &mut Vec<RecordBatch>,
-    schema: SchemaRef,
-    expressions: &[PhysicalSortExpr],
-    target_batch_size: usize,
-    baseline_metrics: BaselineMetrics,
-) -> Result<SendableRecordBatchStream> {
-    if sorted_bathes.len() == 1 {
-        Ok(Box::pin(SizedRecordBatchStream::new(
-            schema,
-            vec![Arc::new(sorted_bathes.pop().unwrap())],
-        )))
-    } else {
-        let new = sorted_bathes.drain(..).collect();
-        assert_eq!(sorted_bathes.len(), 0);
-        Ok(Box::pin(InMemSortStream::new(
-            new,
-            schema,
-            expressions,
-            target_batch_size,
-            baseline_metrics,
-        )?))
-    }
-}
-
-async fn spill_partial_sorted_stream(
-    in_mem_stream: &mut SendableRecordBatchStream,
-    path: String,
-    schema: SchemaRef,
-) -> Result<usize> {
-    let (sender, receiver) = tokio::sync::mpsc::channel(2);
-    while let Some(item) = in_mem_stream.next().await {
-        sender.send(Some(item)).await.ok();
-    }
-    sender.send(None).await.ok();
-    let path_clone = path.clone();
-    let res =
-        task::spawn_blocking(move || write_sorted(receiver, path_clone, schema)).await;
-    match res {
-        Ok(r) => r,
-        Err(e) => Err(DataFusionError::Execution(format!(
-            "Error occurred while spilling {}",
-            e
-        ))),
-    }
-}
-
-async fn read_spill_as_stream(
-    path: String,
-    schema: SchemaRef,
-) -> Result<SendableRecordBatchStream> {
-    let (sender, receiver): (
-        TKSender<ArrowResult<RecordBatch>>,
-        TKReceiver<ArrowResult<RecordBatch>>,
-    ) = tokio::sync::mpsc::channel(2);
-    let path_clone = path.clone();
-    let join_handle = task::spawn_blocking(move || {
-        if let Err(e) = read_spill(sender, path_clone) {
-            error!("Failure while reading spill file: {}. Error: {}", path, e);
-        }
-    });
-    Ok(RecordBatchReceiverStream::create(
-        &schema,
-        receiver,
-        join_handle,
-    ))
-}
-
-fn write_sorted(
-    mut receiver: TKReceiver<Option<ArrowResult<RecordBatch>>>,
-    path: String,
-    schema: SchemaRef,
-) -> Result<usize> {
-    let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?;
-    while let Some(Some(batch)) = receiver.blocking_recv() {
-        writer.write(&batch?)?;
-    }
-    writer.finish()?;
-    info!(
-        "Spilled {} batches of total {} rows to disk, memory released {}",
-        writer.num_batches, writer.num_rows, writer.num_bytes
-    );
-    Ok(writer.num_bytes as usize)
-}
-
-fn read_spill(sender: TKSender<ArrowResult<RecordBatch>>, path: String) -> Result<()> {
-    let file = BufReader::new(File::open(&path)?);
-    let reader = FileReader::try_new(file)?;
-    for batch in reader {
-        sender
-            .blocking_send(batch)
-            .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
-    }
-    Ok(())
-}
-
-/// External Sort execution plan
-#[derive(Debug)]
-pub struct ExternalSortExec {
-    /// Input schema
-    input: Arc<dyn ExecutionPlan>,
-    /// Sort expressions
-    expr: Vec<PhysicalSortExpr>,
-    /// Execution metrics
-    metrics: ExecutionPlanMetricsSet,
-    /// Preserve partitions of input plan
-    preserve_partitioning: bool,
-}
-
-impl ExternalSortExec {
-    /// Create a new sort execution plan
-    pub fn try_new(
-        expr: Vec<PhysicalSortExpr>,
-        input: Arc<dyn ExecutionPlan>,
-    ) -> Result<Self> {
-        Ok(Self::new_with_partitioning(expr, input, false))
-    }
-
-    /// Create a new sort execution plan with the option to preserve
-    /// the partitioning of the input plan
-    pub fn new_with_partitioning(
-        expr: Vec<PhysicalSortExpr>,
-        input: Arc<dyn ExecutionPlan>,
-        preserve_partitioning: bool,
-    ) -> Self {
-        Self {
-            expr,
-            input,
-            metrics: ExecutionPlanMetricsSet::new(),
-            preserve_partitioning,
-        }
-    }
-
-    /// Input schema
-    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
-        &self.input
-    }
-
-    /// Sort expressions
-    pub fn expr(&self) -> &[PhysicalSortExpr] {
-        &self.expr
-    }
-}
-
-#[async_trait]
-impl ExecutionPlan for ExternalSortExec {
-    fn as_any(&self) -> &dyn Any {
-        self
-    }
-
-    fn schema(&self) -> SchemaRef {
-        self.input.schema()
-    }
-
-    /// Get the output partitioning of this plan
-    fn output_partitioning(&self) -> Partitioning {
-        if self.preserve_partitioning {
-            self.input.output_partitioning()
-        } else {
-            Partitioning::UnknownPartitioning(1)
-        }
-    }
-
-    fn required_child_distribution(&self) -> Distribution {
-        if self.preserve_partitioning {
-            Distribution::UnspecifiedDistribution
-        } else {
-            Distribution::SinglePartition
-        }
-    }
-
-    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-        vec![self.input.clone()]
-    }
-
-    fn with_new_children(
-        &self,
-        children: Vec<Arc<dyn ExecutionPlan>>,
-    ) -> Result<Arc<dyn ExecutionPlan>> {
-        match children.len() {
-            1 => Ok(Arc::new(ExternalSortExec::try_new(
-                self.expr.clone(),
-                children[0].clone(),
-            )?)),
-            _ => Err(DataFusionError::Internal(
-                "SortExec wrong number of children".to_string(),
-            )),
-        }
-    }
-
-    async fn execute(
-        &self,
-        partition: usize,
-        runtime: Arc<RuntimeEnv>,
-    ) -> Result<SendableRecordBatchStream> {
-        if !self.preserve_partitioning {
-            if 0 != partition {
-                return Err(DataFusionError::Internal(format!(
-                    "SortExec invalid partition {}",
-                    partition
-                )));
-            }
-
-            // sort needs to operate on a single partition currently
-            if 1 != self.input.output_partitioning().partition_count() {
-                return Err(DataFusionError::Internal(
-                    "SortExec requires a single input partition".to_owned(),
-                ));
-            }
-        }
-
-        let _baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-        let input = self.input.execute(partition, runtime.clone()).await?;
-
-        external_sort(input, partition, self.expr.clone(), runtime).await
-    }
-
-    fn metrics(&self) -> Option<MetricsSet> {
-        Some(self.metrics.clone_inner())
-    }
-
-    fn fmt_as(
-        &self,
-        t: DisplayFormatType,
-        f: &mut std::fmt::Formatter,
-    ) -> std::fmt::Result {
-        match t {
-            DisplayFormatType::Default => {
-                let expr: Vec<String> = self.expr.iter().map(|e| e.to_string()).collect();
-                write!(f, "SortExec: [{}]", expr.join(","))
-            }
-        }
-    }
-
-    fn statistics(&self) -> Statistics {
-        self.input.statistics()
-    }
-}
-
-async fn external_sort(
-    mut input: SendableRecordBatchStream,
-    partition_id: usize,
-    expr: Vec<PhysicalSortExpr>,
-    runtime: Arc<RuntimeEnv>,
-) -> Result<SendableRecordBatchStream> {
-    let schema = input.schema();
-    let sorter = Arc::new(ExternalSorter::new(
-        partition_id,
-        schema.clone(),
-        expr,
-        runtime.clone(),
-    ));
-    runtime.register_consumer(&(sorter.clone() as Arc<dyn MemoryConsumer>));
-
-    while let Some(batch) = input.next().await {
-        let batch = batch?;
-        sorter.insert_batch(batch).await?;
-    }
-
-    let result = sorter.sort().await;
-    runtime.drop_consumer(sorter.id());
-    result
-}
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use crate::datasource::object_store::local::LocalFileSystem;
-    use crate::execution::runtime_env::RuntimeConfig;
-    use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
-    use crate::physical_plan::expressions::col;
-    use crate::physical_plan::{
-        collect,
-        file_format::{CsvExec, FileScanConfig},
-    };
-    use crate::test;
-    use crate::test_util;
-    use arrow::array::*;
-    use arrow::compute::SortOptions;
-    use arrow::datatypes::*;
-
-    async fn sort_with_runtime(runtime: Arc<RuntimeEnv>) -> Result<Vec<RecordBatch>> {
-        let schema = test_util::aggr_test_schema();
-        let partitions = 4;
-        let (_, files) =
-            test::create_partitioned_csv("aggregate_test_100.csv", partitions)?;
-
-        let csv = CsvExec::new(
-            FileScanConfig {
-                object_store: Arc::new(LocalFileSystem {}),
-                file_schema: Arc::clone(&schema),
-                file_groups: files,
-                statistics: Statistics::default(),
-                projection: None,
-                limit: None,
-                table_partition_cols: vec![],
-            },
-            true,
-            b',',
-        );
-
-        let sort_exec = Arc::new(ExternalSortExec::try_new(
-            vec![
-                // c1 string column
-                PhysicalSortExpr {
-                    expr: col("c1", &schema)?,
-                    options: SortOptions::default(),
-                },
-                // c2 uin32 column
-                PhysicalSortExpr {
-                    expr: col("c2", &schema)?,
-                    options: SortOptions::default(),
-                },
-                // c7 uin8 column
-                PhysicalSortExpr {
-                    expr: col("c7", &schema)?,
-                    options: SortOptions::default(),
-                },
-            ],
-            Arc::new(CoalescePartitionsExec::new(Arc::new(csv))),
-        )?);
-
-        collect(sort_exec, runtime).await
-    }
-
-    #[tokio::test]
-    async fn test_in_mem_sort() -> Result<()> {
-        let runtime = Arc::new(RuntimeEnv::default());
-        let result = sort_with_runtime(runtime).await?;
-
-        assert_eq!(result.len(), 1);
-
-        let columns = result[0].columns();
-
-        let c1 = as_string_array(&columns[0]);
-        assert_eq!(c1.value(0), "a");
-        assert_eq!(c1.value(c1.len() - 1), "e");
-
-        let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
-        assert_eq!(c2.value(0), 1);
-        assert_eq!(c2.value(c2.len() - 1), 5,);
-
-        let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
-        assert_eq!(c7.value(0), 15);
-        assert_eq!(c7.value(c7.len() - 1), 254,);
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_sort_spill() -> Result<()> {
-        let config = RuntimeConfig::new()
-            .with_memory_fraction(1.0)
-            // trigger spill there will be 4 batches with 5.5KB for each
-            .with_max_execution_memory(12288);
-        let runtime = Arc::new(RuntimeEnv::new(config)?);
-        let result = sort_with_runtime(runtime).await?;
-
-        assert_eq!(result.len(), 1);
-
-        let columns = result[0].columns();
-
-        let c1 = as_string_array(&columns[0]);
-        assert_eq!(c1.value(0), "a");
-        assert_eq!(c1.value(c1.len() - 1), "e");
-
-        let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
-        assert_eq!(c2.value(0), 1);
-        assert_eq!(c2.value(c2.len() - 1), 5,);
-
-        let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
-        assert_eq!(c7.value(0), 15);
-        assert_eq!(c7.value(c7.len() - 1), 254,);
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_multi_output_batch() -> Result<()> {
-        let config = RuntimeConfig::new().with_batch_size(26);
-        let runtime = Arc::new(RuntimeEnv::new(config)?);
-        let result = sort_with_runtime(runtime).await?;
-
-        assert_eq!(result.len(), 4);
-
-        let columns_b1 = result[0].columns();
-        let columns_b3 = result[3].columns();
-
-        let c1 = as_string_array(&columns_b1[0]);
-        let c13 = as_string_array(&columns_b3[0]);
-        assert_eq!(c1.value(0), "a");
-        assert_eq!(c13.value(c13.len() - 1), "e");
-
-        let c2 = as_primitive_array::<UInt32Type>(&columns_b1[1]);
-        let c23 = as_primitive_array::<UInt32Type>(&columns_b3[1]);
-        assert_eq!(c2.value(0), 1);
-        assert_eq!(c23.value(c23.len() - 1), 5,);
-
-        let c7 = as_primitive_array::<UInt8Type>(&columns_b1[6]);
-        let c73 = as_primitive_array::<UInt8Type>(&columns_b3[6]);
-        assert_eq!(c7.value(0), 15);
-        assert_eq!(c73.value(c73.len() - 1), 254,);
-
-        Ok(())
-    }
-}
diff --git a/datafusion/src/physical_plan/sorts/in_mem_sort.rs b/datafusion/src/physical_plan/sorts/in_mem_sort.rs
deleted file mode 100644
index 9e7753d..0000000
--- a/datafusion/src/physical_plan/sorts/in_mem_sort.rs
+++ /dev/null
@@ -1,241 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-use std::collections::BinaryHeap;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
-use arrow::{
-    array::{make_array as make_arrow_array, MutableArrayData},
-    compute::SortOptions,
-    datatypes::SchemaRef,
-    error::{ArrowError, Result as ArrowResult},
-    record_batch::RecordBatch,
-};
-use futures::Stream;
-
-use crate::error::Result;
-use crate::physical_plan::metrics::BaselineMetrics;
-use crate::physical_plan::sorts::{RowIndex, SortKeyCursor};
-use crate::physical_plan::{
-    expressions::PhysicalSortExpr, PhysicalExpr, RecordBatchStream,
-};
-
-/// Merge buffered, self-sorted record batches to get an order.
-///
-/// Internally, it uses MinHeap to reduce extra memory consumption
-/// by not concatenating all batches into one and sorting it as done by `SortExec`.
-pub(crate) struct InMemSortStream {
-    /// The schema of the RecordBatches yielded by this stream
-    schema: SchemaRef,
-    /// Self sorted batches to be merged together
-    batches: Vec<Arc<RecordBatch>>,
-    /// The accumulated row indexes for the next record batch
-    in_progress: Vec<RowIndex>,
-    /// The desired RecordBatch size to yield
-    target_batch_size: usize,
-    /// used to record execution metrics
-    baseline_metrics: BaselineMetrics,
-    /// If the stream has encountered an error
-    aborted: bool,
-    /// min heap for record comparison
-    min_heap: BinaryHeap<SortKeyCursor>,
-}
-
-impl InMemSortStream {
-    pub(crate) fn new(
-        sorted_batches: Vec<RecordBatch>,
-        schema: SchemaRef,
-        expressions: &[PhysicalSortExpr],
-        target_batch_size: usize,
-        baseline_metrics: BaselineMetrics,
-    ) -> Result<Self> {
-        let len = sorted_batches.len();
-        let mut cursors = Vec::with_capacity(len);
-        let mut min_heap = BinaryHeap::with_capacity(len);
-
-        let column_expressions: Vec<Arc<dyn PhysicalExpr>> =
-            expressions.iter().map(|x| x.expr.clone()).collect();
-
-        // The sort options for each expression
-        let sort_options: Arc<Vec<SortOptions>> =
-            Arc::new(expressions.iter().map(|x| x.options).collect());
-
-        sorted_batches
-            .into_iter()
-            .enumerate()
-            .try_for_each(|(idx, batch)| {
-                let batch = Arc::new(batch);
-                let cursor = match SortKeyCursor::new(
-                    idx,
-                    batch.clone(),
-                    &column_expressions,
-                    sort_options.clone(),
-                ) {
-                    Ok(cursor) => cursor,
-                    Err(e) => return Err(e),
-                };
-                min_heap.push(cursor);
-                cursors.insert(idx, batch);
-                Ok(())
-            })?;
-
-        Ok(Self {
-            schema,
-            batches: cursors,
-            target_batch_size,
-            baseline_metrics,
-            aborted: false,
-            in_progress: vec![],
-            min_heap,
-        })
-    }
-
-    /// Returns the index of the next batch to pull a row from, or None
-    /// if all cursors for all batch are exhausted
-    fn next_cursor(&mut self) -> Result<Option<SortKeyCursor>> {
-        match self.min_heap.pop() {
-            None => Ok(None),
-            Some(cursor) => Ok(Some(cursor)),
-        }
-    }
-
-    /// Drains the in_progress row indexes, and builds a new RecordBatch from them
-    ///
-    /// Will then drop any cursors for which all rows have been yielded to the output
-    fn build_record_batch(&mut self) -> ArrowResult<RecordBatch> {
-        let columns = self
-            .schema
-            .fields()
-            .iter()
-            .enumerate()
-            .map(|(column_idx, field)| {
-                let arrays = self
-                    .batches
-                    .iter()
-                    .map(|batch| batch.column(column_idx).data())
-                    .collect();
-
-                let mut array_data = MutableArrayData::new(
-                    arrays,
-                    field.is_nullable(),
-                    self.in_progress.len(),
-                );
-
-                if self.in_progress.is_empty() {
-                    return make_arrow_array(array_data.freeze());
-                }
-
-                let first = &self.in_progress[0];
-                let mut buffer_idx = first.stream_idx;
-                let mut start_row_idx = first.row_idx;
-                let mut end_row_idx = start_row_idx + 1;
-
-                for row_index in self.in_progress.iter().skip(1) {
-                    let next_buffer_idx = row_index.stream_idx;
-
-                    if next_buffer_idx == buffer_idx && row_index.row_idx == end_row_idx {
-                        // subsequent row in same batch
-                        end_row_idx += 1;
-                        continue;
-                    }
-
-                    // emit current batch of rows for current buffer
-                    array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-
-                    // start new batch of rows
-                    buffer_idx = next_buffer_idx;
-                    start_row_idx = row_index.row_idx;
-                    end_row_idx = start_row_idx + 1;
-                }
-
-                // emit final batch of rows
-                array_data.extend(buffer_idx, start_row_idx, end_row_idx);
-                make_arrow_array(array_data.freeze())
-            })
-            .collect();
-
-        self.in_progress.clear();
-        RecordBatch::try_new(self.schema.clone(), columns)
-    }
-
-    #[inline]
-    fn poll_next_inner(
-        self: &mut Pin<&mut Self>,
-        _cx: &mut Context<'_>,
-    ) -> Poll<Option<ArrowResult<RecordBatch>>> {
-        if self.aborted {
-            return Poll::Ready(None);
-        }
-
-        loop {
-            // NB timer records time taken on drop, so there are no
-            // calls to `timer.done()` below.
-            let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
-            let _timer = elapsed_compute.timer();
-
-            match self.next_cursor() {
-                Ok(Some(mut cursor)) => {
-                    let batch_idx = cursor.batch_idx;
-                    let row_idx = cursor.advance();
-
-                    // insert the cursor back to min_heap if the record batch is not exhausted
-                    if !cursor.is_finished() {
-                        self.min_heap.push(cursor);
-                    }
-
-                    self.in_progress.push(RowIndex {
-                        stream_idx: batch_idx,
-                        cursor_idx: 0,
-                        row_idx,
-                    });
-                }
-                Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None),
-                Ok(None) => return Poll::Ready(Some(self.build_record_batch())),
-                Err(e) => {
-                    self.aborted = true;
-                    return Poll::Ready(Some(Err(ArrowError::ExternalError(Box::new(
-                        e,
-                    )))));
-                }
-            };
-
-            if self.in_progress.len() == self.target_batch_size {
-                return Poll::Ready(Some(self.build_record_batch()));
-            }
-        }
-    }
-}
-
-impl Stream for InMemSortStream {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        let poll = self.poll_next_inner(cx);
-        self.baseline_metrics.record_poll(poll)
-    }
-}
-
-impl RecordBatchStream for InMemSortStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
-}
diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs
index 3dda13b..1bb880f 100644
--- a/datafusion/src/physical_plan/sorts/mod.rs
+++ b/datafusion/src/physical_plan/sorts/mod.rs
@@ -32,11 +32,10 @@ use std::borrow::BorrowMut;
 use std::cmp::Ordering;
 use std::fmt::{Debug, Formatter};
 use std::pin::Pin;
+use std::sync::atomic::AtomicUsize;
 use std::sync::{Arc, RwLock};
 use std::task::{Context, Poll};
 
-pub mod external_sort;
-mod in_mem_sort;
 pub mod sort;
 pub mod sort_preserving_merge;
 
@@ -50,8 +49,9 @@ pub mod sort_preserving_merge;
 /// by this row cursor, with that of another `SortKeyCursor`. A cursor stores
 /// a row comparator for each other cursor that it is compared to.
 struct SortKeyCursor {
-    columns: Vec<ArrayRef>,
-    cur_row: usize,
+    stream_idx: usize,
+    sort_columns: Vec<ArrayRef>,
+    cur_row: AtomicUsize,
     num_rows: usize,
 
     // An index uniquely identifying the record batch scanned by this cursor.
@@ -68,8 +68,8 @@ struct SortKeyCursor {
 impl<'a> std::fmt::Debug for SortKeyCursor {
     fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
         f.debug_struct("SortKeyCursor")
-            .field("columns", &self.columns)
-            .field("cur_row", &self.cur_row)
+            .field("sort_columns", &self.sort_columns)
+            .field("cur_row", &self.cur_row())
             .field("num_rows", &self.num_rows)
             .field("batch_idx", &self.batch_idx)
             .field("batch", &self.batch)
@@ -80,19 +80,21 @@ impl<'a> std::fmt::Debug for SortKeyCursor {
 
 impl SortKeyCursor {
     fn new(
+        stream_idx: usize,
         batch_idx: usize,
         batch: Arc<RecordBatch>,
         sort_key: &[Arc<dyn PhysicalExpr>],
         sort_options: Arc<Vec<SortOptions>>,
     ) -> error::Result<Self> {
-        let columns = sort_key
+        let sort_columns = sort_key
             .iter()
             .map(|expr| 
Ok(expr.evaluate(&batch)?.into_array(batch.num_rows())))
             .collect::<error::Result<_>>()?;
         Ok(Self {
-            cur_row: 0,
+            stream_idx,
+            cur_row: AtomicUsize::new(0),
             num_rows: batch.num_rows(),
-            columns,
+            sort_columns,
             batch,
             batch_idx,
             batch_comparators: RwLock::new(HashMap::new()),
@@ -101,38 +103,41 @@ impl SortKeyCursor {
     }
 
     fn is_finished(&self) -> bool {
-        self.num_rows == self.cur_row
+        self.num_rows == self.cur_row()
     }
 
-    fn advance(&mut self) -> usize {
+    fn advance(&self) -> usize {
         assert!(!self.is_finished());
-        let t = self.cur_row;
-        self.cur_row += 1;
-        t
+        self.cur_row
+            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
+    }
+
+    fn cur_row(&self) -> usize {
+        self.cur_row.load(std::sync::atomic::Ordering::SeqCst)
     }
 
     /// Compares the sort key pointed to by this instance's row cursor with that of another
     fn compare(&self, other: &SortKeyCursor) -> error::Result<Ordering> {
-        if self.columns.len() != other.columns.len() {
+        if self.sort_columns.len() != other.sort_columns.len() {
             return Err(DataFusionError::Internal(format!(
                 "SortKeyCursors had inconsistent column counts: {} vs {}",
-                self.columns.len(),
-                other.columns.len()
+                self.sort_columns.len(),
+                other.sort_columns.len()
             )));
         }
 
-        if self.columns.len() != self.sort_options.len() {
+        if self.sort_columns.len() != self.sort_options.len() {
             return Err(DataFusionError::Internal(format!(
                 "Incorrect number of SortOptions provided to 
SortKeyCursor::compare, expected {} got {}",
-                self.columns.len(),
+                self.sort_columns.len(),
                 self.sort_options.len()
             )));
         }
 
         let zipped: Vec<((&ArrayRef, &ArrayRef), &SortOptions)> = self
-            .columns
+            .sort_columns
             .iter()
-            .zip(other.columns.iter())
+            .zip(other.sort_columns.iter())
             .zip(self.sort_options.iter())
             .collect::<Vec<_>>();
 
@@ -146,7 +151,7 @@ impl SortKeyCursor {
         })?;
 
         for (i, ((l, r), sort_options)) in zipped.iter().enumerate() {
-            match (l.is_valid(self.cur_row), r.is_valid(other.cur_row)) {
+            match (l.is_valid(self.cur_row()), r.is_valid(other.cur_row())) {
                 (false, true) if sort_options.nulls_first => return 
Ok(Ordering::Less),
                 (false, true) => return Ok(Ordering::Greater),
                 (true, false) if sort_options.nulls_first => {
@@ -154,7 +159,7 @@ impl SortKeyCursor {
                 }
                 (true, false) => return Ok(Ordering::Less),
                 (false, false) => {}
-                (true, true) => match cmp[i](self.cur_row, other.cur_row) {
+                (true, true) => match cmp[i](self.cur_row(), other.cur_row()) {
                     Ordering::Equal => {}
                     o if sort_options.descending => return Ok(o.reverse()),
                     o => return Ok(o),
@@ -179,7 +184,7 @@ impl SortKeyCursor {
             let cmp = map
                 .borrow_mut()
                 .entry(other.batch_idx)
-                .or_insert_with(|| Vec::with_capacity(other.columns.len()));
+                .or_insert_with(|| 
Vec::with_capacity(other.sort_columns.len()));
 
             for (i, ((l, r), _)) in zipped.iter().enumerate() {
                 if i >= cmp.len() {
@@ -193,7 +198,7 @@ impl SortKeyCursor {
 }
 
 impl Ord for SortKeyCursor {
-    /// Needed by min-heap comparison in `in_mem_sort` and reverse the order at the same time.
+    /// Needed by the min-heap comparison; reverses the order at the same time.
     fn cmp(&self, other: &Self) -> Ordering {
         other.compare(self).unwrap()
     }
@@ -219,8 +224,7 @@ impl PartialOrd for SortKeyCursor {
 struct RowIndex {
     /// The index of the stream
     stream_idx: usize,
-    /// For sort_preserving_merge, it's the index of the cursor within the stream's VecDequeue.
-    /// For in_mem_sort which have only one batch for each stream, cursor_idx always 0
+    /// The index of the cursor within the stream's VecDeque.
     cursor_idx: usize,
     /// The row index
     row_idx: usize,
@@ -251,10 +255,9 @@ enum StreamWrapper {
 
 impl StreamWrapper {
     fn mem_used(&self) -> usize {
-        if let StreamWrapper::Stream(Some(s)) = &self {
-            s.mem_used
-        } else {
-            0
+        match &self {
+            StreamWrapper::Stream(Some(s)) => s.mem_used,
+            _ => 0,
         }
     }
 }
diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs
index a210b93..c3a138e 100644
--- a/datafusion/src/physical_plan/sorts/sort.rs
+++ b/datafusion/src/physical_plan/sorts/sort.rs
@@ -15,47 +15,433 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! Defines the SORT plan
+//! Sort that deals with an arbitrary size of the input.
+//! It will do in-memory sorting if it has enough memory budget
+//! but spills to disk if needed.
 
 use crate::error::{DataFusionError, Result};
+use crate::execution::memory_manager::{
+    ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager,
+};
 use crate::execution::runtime_env::RuntimeEnv;
-use crate::physical_plan::common::AbortOnDropSingle;
+use crate::physical_plan::common::{batch_byte_size, IPCWriter, 
SizedRecordBatchStream};
 use crate::physical_plan::expressions::PhysicalSortExpr;
 use crate::physical_plan::metrics::{
-    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
+    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricsSet, Time,
 };
+use 
crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
+use crate::physical_plan::sorts::SortedStream;
+use crate::physical_plan::stream::RecordBatchReceiverStream;
 use crate::physical_plan::{
-    common, DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
+    common, DisplayFormatType, Distribution, EmptyRecordBatchStream, 
ExecutionPlan,
+    Partitioning, SendableRecordBatchStream, Statistics,
 };
-use crate::physical_plan::{RecordBatchStream, SendableRecordBatchStream, 
Statistics};
+use arrow::array::ArrayRef;
 pub use arrow::compute::SortOptions;
 use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions};
 use arrow::datatypes::SchemaRef;
 use arrow::error::Result as ArrowResult;
+use arrow::ipc::reader::FileReader;
 use arrow::record_batch::RecordBatch;
-use arrow::{array::ArrayRef, error::ArrowError};
 use async_trait::async_trait;
-use futures::stream::Stream;
-use futures::Future;
-use pin_project_lite::pin_project;
+use futures::lock::Mutex;
+use futures::StreamExt;
+use log::{error, info};
 use std::any::Any;
-use std::pin::Pin;
+use std::fmt;
+use std::fmt::{Debug, Formatter};
+use std::fs::File;
+use std::io::BufReader;
+use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
-use std::task::{Context, Poll};
+use std::time::Duration;
+use tokio::sync::mpsc::{Receiver as TKReceiver, Sender as TKSender};
+use tokio::task;
+
+/// Sort data of arbitrary size into a total order (may spill several times during sorting, depending on the free memory available).
+///
+/// The basic architecture of the algorithm:
+/// 1. get a non-empty new batch from input
+/// 2. check with the memory manager if we could buffer the batch in memory
+/// 2.1 if memory is sufficient, buffer the batch in memory and go to 1.
+/// 2.2 if the memory threshold is reached, sort all buffered batches and spill them to a file,
+///     then buffer the new batch in memory and go to 1.
+/// 3. when input is exhausted, merge all in-memory batches and spills to get a total order.
+struct ExternalSorter {
+    id: MemoryConsumerId,
+    schema: SchemaRef,
+    in_mem_batches: Mutex<Vec<RecordBatch>>,
+    spills: Mutex<Vec<String>>,
+    /// Sort expressions
+    expr: Vec<PhysicalSortExpr>,
+    runtime: Arc<RuntimeEnv>,
+    metrics: AggregatedMetricsSet,
+    used: AtomicUsize,
+    spilled_bytes: AtomicUsize,
+    spilled_count: AtomicUsize,
+}
+
+impl ExternalSorter {
+    pub fn new(
+        partition_id: usize,
+        schema: SchemaRef,
+        expr: Vec<PhysicalSortExpr>,
+        metrics: AggregatedMetricsSet,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Self {
+        Self {
+            id: MemoryConsumerId::new(partition_id),
+            schema,
+            in_mem_batches: Mutex::new(vec![]),
+            spills: Mutex::new(vec![]),
+            expr,
+            runtime,
+            metrics,
+            used: AtomicUsize::new(0),
+            spilled_bytes: AtomicUsize::new(0),
+            spilled_count: AtomicUsize::new(0),
+        }
+    }
+
+    async fn insert_batch(&self, input: RecordBatch) -> Result<()> {
+        if input.num_rows() > 0 {
+            let size = batch_byte_size(&input);
+            self.try_grow(size).await?;
+            self.used.fetch_add(size, Ordering::SeqCst);
+            let mut in_mem_batches = self.in_mem_batches.lock().await;
+            in_mem_batches.push(input);
+        }
+        Ok(())
+    }
+
+    async fn spilled_before(&self) -> bool {
+        let spills = self.spills.lock().await;
+        !spills.is_empty()
+    }
+
+    /// Merge-sort the in-memory batches as well as the spills into a total order with `SortPreservingMergeStream`.
+    async fn sort(&self) -> Result<SendableRecordBatchStream> {
+        let partition = self.partition_id();
+        let mut in_mem_batches = self.in_mem_batches.lock().await;
+
+        if self.spilled_before().await {
+            let baseline_metrics = 
self.metrics.new_intermediate_baseline(partition);
+            let mut streams: Vec<SortedStream> = vec![];
+            if in_mem_batches.len() > 0 {
+                let in_mem_stream = in_mem_partial_sort(
+                    &mut *in_mem_batches,
+                    self.schema.clone(),
+                    &self.expr,
+                    baseline_metrics,
+                )?;
+                streams.push(SortedStream::new(in_mem_stream, self.used()));
+            }
 
-/// Sort execution plan
+            let mut spills = self.spills.lock().await;
+
+            for spill in spills.drain(..) {
+                let stream = read_spill_as_stream(spill, self.schema.clone())?;
+                streams.push(SortedStream::new(stream, 0));
+            }
+            let baseline_metrics = self.metrics.new_final_baseline(partition);
+            Ok(Box::pin(SortPreservingMergeStream::new_from_streams(
+                streams,
+                self.schema.clone(),
+                &self.expr,
+                baseline_metrics,
+                partition,
+                self.runtime.clone(),
+            )))
+        } else if in_mem_batches.len() > 0 {
+            let baseline_metrics = self.metrics.new_final_baseline(partition);
+            in_mem_partial_sort(
+                &mut *in_mem_batches,
+                self.schema.clone(),
+                &self.expr,
+                baseline_metrics,
+            )
+        } else {
+            Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
+        }
+    }
+
+    fn used(&self) -> usize {
+        self.used.load(Ordering::SeqCst)
+    }
+
+    fn spilled_bytes(&self) -> usize {
+        self.spilled_bytes.load(Ordering::SeqCst)
+    }
+
+    fn spilled_count(&self) -> usize {
+        self.spilled_count.load(Ordering::SeqCst)
+    }
+}
+
+impl Debug for ExternalSorter {
+    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+        f.debug_struct("ExternalSorter")
+            .field("id", &self.id())
+            .field("memory_used", &self.used())
+            .field("spilled_bytes", &self.spilled_bytes())
+            .field("spilled_count", &self.spilled_count())
+            .finish()
+    }
+}
+
+#[async_trait]
+impl MemoryConsumer for ExternalSorter {
+    fn name(&self) -> String {
+        "ExternalSorter".to_owned()
+    }
+
+    fn id(&self) -> &MemoryConsumerId {
+        &self.id
+    }
+
+    fn memory_manager(&self) -> Arc<MemoryManager> {
+        self.runtime.memory_manager.clone()
+    }
+
+    fn type_(&self) -> &ConsumerType {
+        &ConsumerType::Requesting
+    }
+
+    async fn spill(&self) -> Result<usize> {
+        info!(
+            "{}[{}] spilling sort data of {} to disk while inserting ({} 
time(s) so far)",
+            self.name(),
+            self.id(),
+            self.used(),
+            self.spilled_count()
+        );
+
+        let partition = self.partition_id();
+        let mut in_mem_batches = self.in_mem_batches.lock().await;
+        // we can only free memory if we are holding some buffered batches
+        if in_mem_batches.len() == 0 {
+            return Ok(0);
+        }
+
+        let baseline_metrics = 
self.metrics.new_intermediate_baseline(partition);
+
+        let path = self.runtime.disk_manager.create_tmp_file()?;
+        let stream = in_mem_partial_sort(
+            &mut *in_mem_batches,
+            self.schema.clone(),
+            &*self.expr,
+            baseline_metrics,
+        );
+
+        let total_size =
+            spill_partial_sorted_stream(&mut stream?, path.clone(), 
self.schema.clone())
+                .await?;
+
+        let mut spills = self.spills.lock().await;
+        let used = self.used.swap(0, Ordering::SeqCst);
+        self.spilled_count.fetch_add(1, Ordering::SeqCst);
+        self.spilled_bytes.fetch_add(total_size, Ordering::SeqCst);
+        spills.push(path);
+        Ok(used)
+    }
+
+    fn mem_used(&self) -> usize {
+        self.used.load(Ordering::SeqCst)
+    }
+}
+
+/// Consume the non-empty `buffered_batches` and sort them in memory into a single batch
+fn in_mem_partial_sort(
+    buffered_batches: &mut Vec<RecordBatch>,
+    schema: SchemaRef,
+    expressions: &[PhysicalSortExpr],
+    baseline_metrics: BaselineMetrics,
+) -> Result<SendableRecordBatchStream> {
+    assert_ne!(buffered_batches.len(), 0);
+
+    let result = {
+        // NB timer records time taken on drop, so there are no
+        // calls to `timer.done()` below.
+        let _timer = baseline_metrics.elapsed_compute().timer();
+
+        let pre_sort = if buffered_batches.len() == 1 {
+            buffered_batches.pop()
+        } else {
+            let batches = buffered_batches.drain(..).collect::<Vec<_>>();
+            // combine all record batches into one for each column
+            common::combine_batches(&batches, schema.clone())?
+        };
+
+        pre_sort
+            .map(|batch| sort_batch(batch, schema.clone(), expressions))
+            .transpose()?
+    };
+
+    Ok(Box::pin(SizedRecordBatchStream::new(
+        schema,
+        vec![Arc::new(result.unwrap())],
+        baseline_metrics,
+    )))
+}
+
+async fn spill_partial_sorted_stream(
+    in_mem_stream: &mut SendableRecordBatchStream,
+    path: String,
+    schema: SchemaRef,
+) -> Result<usize> {
+    let (sender, receiver) = tokio::sync::mpsc::channel(2);
+    let path_clone = path.clone();
+    let handle = task::spawn_blocking(move || write_sorted(receiver, 
path_clone, schema));
+    while let Some(item) = in_mem_stream.next().await {
+        sender.send(item).await.ok();
+    }
+    drop(sender);
+    match handle.await {
+        Ok(r) => r,
+        Err(e) => Err(DataFusionError::Execution(format!(
+            "Error occurred while spilling {}",
+            e
+        ))),
+    }
+}
+
+fn read_spill_as_stream(
+    path: String,
+    schema: SchemaRef,
+) -> Result<SendableRecordBatchStream> {
+    let (sender, receiver): (
+        TKSender<ArrowResult<RecordBatch>>,
+        TKReceiver<ArrowResult<RecordBatch>>,
+    ) = tokio::sync::mpsc::channel(2);
+    let path_clone = path.clone();
+    let join_handle = task::spawn_blocking(move || {
+        if let Err(e) = read_spill(sender, path_clone) {
+            error!("Failure while reading spill file: {}. Error: {}", path, e);
+        }
+    });
+    Ok(RecordBatchReceiverStream::create(
+        &schema,
+        receiver,
+        join_handle,
+    ))
+}
+
+fn write_sorted(
+    mut receiver: TKReceiver<ArrowResult<RecordBatch>>,
+    path: String,
+    schema: SchemaRef,
+) -> Result<usize> {
+    let mut writer = IPCWriter::new(path.as_ref(), schema.as_ref())?;
+    while let Some(batch) = receiver.blocking_recv() {
+        writer.write(&batch?)?;
+    }
+    writer.finish()?;
+    info!(
+        "Spilled {} batches of total {} rows to disk, memory released {}",
+        writer.num_batches, writer.num_rows, writer.num_bytes
+    );
+    Ok(writer.num_bytes as usize)
+}
+
+fn read_spill(sender: TKSender<ArrowResult<RecordBatch>>, path: String) -> 
Result<()> {
+    let file = BufReader::new(File::open(&path)?);
+    let reader = FileReader::try_new(file)?;
+    for batch in reader {
+        sender
+            .blocking_send(batch)
+            .map_err(|e| DataFusionError::Execution(format!("{}", e)))?;
+    }
+    Ok(())
+}
+
+/// External Sort execution plan
 #[derive(Debug)]
 pub struct SortExec {
     /// Input schema
     input: Arc<dyn ExecutionPlan>,
     /// Sort expressions
     expr: Vec<PhysicalSortExpr>,
-    /// Execution metrics
-    metrics: ExecutionPlanMetricsSet,
+    /// Containing all metrics set created during sort
+    all_metrics: AggregatedMetricsSet,
     /// Preserve partitions of input plan
     preserve_partitioning: bool,
 }
 
+#[derive(Debug, Clone)]
+/// Aggregates all metrics during a complex operation, which is composed of multiple stages and
+/// each stage reports its statistics separately.
+/// Take sort as an example: when the dataset is larger than the available memory, it will report
+/// multiple in-mem sort metrics and final merge-sort metrics from `SortPreservingMergeStream`.
+/// Therefore, we need to separate the final metrics (used for output_rows accumulation)
+/// from the intermediate metrics, for which we only account elapsed_compute time.
+struct AggregatedMetricsSet {
+    intermediate: Arc<std::sync::Mutex<Vec<ExecutionPlanMetricsSet>>>,
+    final_: Arc<std::sync::Mutex<Vec<ExecutionPlanMetricsSet>>>,
+}
+
+impl AggregatedMetricsSet {
+    fn new() -> Self {
+        Self {
+            intermediate: Arc::new(std::sync::Mutex::new(vec![])),
+            final_: Arc::new(std::sync::Mutex::new(vec![])),
+        }
+    }
+
+    fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics {
+        let ms = ExecutionPlanMetricsSet::new();
+        let result = BaselineMetrics::new(&ms, partition);
+        self.intermediate.lock().unwrap().push(ms);
+        result
+    }
+
+    fn new_final_baseline(&self, partition: usize) -> BaselineMetrics {
+        let ms = ExecutionPlanMetricsSet::new();
+        let result = BaselineMetrics::new(&ms, partition);
+        self.final_.lock().unwrap().push(ms);
+        result
+    }
+
+    /// Accumulate the compute time reported by all stages to get the total time consumption.
+    fn merge_compute_time(&self, dest: &Time) {
+        let time1 = self
+            .intermediate
+            .lock()
+            .unwrap()
+            .iter()
+            .map(|es| {
+                es.clone_inner()
+                    .elapsed_compute()
+                    .map_or(0u64, |v| v as u64)
+            })
+            .sum();
+        let time2 = self
+            .final_
+            .lock()
+            .unwrap()
+            .iter()
+            .map(|es| {
+                es.clone_inner()
+                    .elapsed_compute()
+                    .map_or(0u64, |v| v as u64)
+            })
+            .sum();
+        dest.add_duration(Duration::from_nanos(time1));
+        dest.add_duration(Duration::from_nanos(time2));
+    }
+
+    /// We should only care about output from the final stage metrics.
+    fn merge_output_count(&self, dest: &Count) {
+        let count = self
+            .final_
+            .lock()
+            .unwrap()
+            .iter()
+            .map(|es| es.clone_inner().output_rows().map_or(0, |v| v))
+            .sum();
+        dest.add(count);
+    }
+}
+
 impl SortExec {
     /// Create a new sort execution plan
     pub fn try_new(
@@ -75,7 +461,7 @@ impl SortExec {
         Self {
             expr,
             input,
-            metrics: ExecutionPlanMetricsSet::new(),
+            all_metrics: AggregatedMetricsSet::new(),
             preserve_partitioning,
         }
     }
@@ -93,7 +479,6 @@ impl SortExec {
 
 #[async_trait]
 impl ExecutionPlan for SortExec {
-    /// Return a reference to Any that can be used for downcasting
     fn as_any(&self) -> &dyn Any {
         self
     }
@@ -102,10 +487,6 @@ impl ExecutionPlan for SortExec {
         self.input.schema()
     }
 
-    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
-        vec![self.input.clone()]
-    }
-
     /// Get the output partitioning of this plan
     fn output_partitioning(&self) -> Partitioning {
         if self.preserve_partitioning {
@@ -123,6 +504,10 @@ impl ExecutionPlan for SortExec {
         }
     }
 
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        vec![self.input.clone()]
+    }
+
     fn with_new_children(
         &self,
         children: Vec<Arc<dyn ExecutionPlan>>,
@@ -159,14 +544,25 @@ impl ExecutionPlan for SortExec {
             }
         }
 
-        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
-        let input = self.input.execute(partition, runtime).await?;
+        let input = self.input.execute(partition, runtime.clone()).await?;
 
-        Ok(Box::pin(SortStream::new(
+        do_sort(
             input,
+            partition,
             self.expr.clone(),
-            baseline_metrics,
-        )))
+            self.all_metrics.clone(),
+            runtime,
+        )
+        .await
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        let metrics = ExecutionPlanMetricsSet::new();
+        let baseline = BaselineMetrics::new(&metrics, 0);
+        self.all_metrics
+            .merge_compute_time(baseline.elapsed_compute());
+        self.all_metrics.merge_output_count(baseline.output_rows());
+        Some(metrics.clone_inner())
     }
 
     fn fmt_as(
@@ -182,16 +578,12 @@ impl ExecutionPlan for SortExec {
         }
     }
 
-    fn metrics(&self) -> Option<MetricsSet> {
-        Some(self.metrics.clone_inner())
-    }
-
     fn statistics(&self) -> Statistics {
         self.input.statistics()
     }
 }
 
-pub(crate) fn sort_batch(
+fn sort_batch(
     batch: RecordBatch,
     schema: SchemaRef,
     expr: &[PhysicalSortExpr],
@@ -227,97 +619,38 @@ pub(crate) fn sort_batch(
     )
 }
 
-pin_project! {
-    /// stream for sort plan
-    struct SortStream {
-        #[pin]
-        output: 
futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
-        finished: bool,
-        schema: SchemaRef,
-        drop_helper: AbortOnDropSingle<()>,
-    }
-}
-
-impl SortStream {
-    fn new(
-        input: SendableRecordBatchStream,
-        expr: Vec<PhysicalSortExpr>,
-        baseline_metrics: BaselineMetrics,
-    ) -> Self {
-        let (tx, rx) = futures::channel::oneshot::channel();
-        let schema = input.schema();
-        let join_handle = tokio::spawn(async move {
-            let schema = input.schema();
-            let sorted_batch = common::collect(input)
-                .await
-                .map_err(DataFusionError::into_arrow_external_error)
-                .and_then(move |batches| {
-                    let timer = baseline_metrics.elapsed_compute().timer();
-                    // combine all record batches into one for each column
-                    let combined = common::combine_batches(&batches, 
schema.clone())?;
-                    // sort combined record batch
-                    let result = combined
-                        .map(|batch| sort_batch(batch, schema, &expr))
-                        .transpose()?
-                        .record_output(&baseline_metrics);
-                    timer.done();
-                    Ok(result)
-                });
-
-            // failing here is OK, the receiver is gone and does not care about the result
-            tx.send(sorted_batch).ok();
-        });
-
-        Self {
-            output: rx,
-            finished: false,
-            schema,
-            drop_helper: AbortOnDropSingle::new(join_handle),
-        }
-    }
-}
-
-impl Stream for SortStream {
-    type Item = ArrowResult<RecordBatch>;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> 
Poll<Option<Self::Item>> {
-        if self.finished {
-            return Poll::Ready(None);
-        }
-
-        // is the output ready?
-        let this = self.project();
-        let output_poll = this.output.poll(cx);
-
-        match output_poll {
-            Poll::Ready(result) => {
-                *this.finished = true;
-
-                // check for error in receiving channel and unwrap actual result
-                let result = match result {
-                    Err(e) => 
Some(Err(ArrowError::ExternalError(Box::new(e)))), // error receiving
-                    Ok(result) => result.transpose(),
-                };
-
-                Poll::Ready(result)
-            }
-            Poll::Pending => Poll::Pending,
-        }
+async fn do_sort(
+    mut input: SendableRecordBatchStream,
+    partition_id: usize,
+    expr: Vec<PhysicalSortExpr>,
+    metrics: AggregatedMetricsSet,
+    runtime: Arc<RuntimeEnv>,
+) -> Result<SendableRecordBatchStream> {
+    let schema = input.schema();
+    let sorter = Arc::new(ExternalSorter::new(
+        partition_id,
+        schema.clone(),
+        expr,
+        metrics,
+        runtime.clone(),
+    ));
+    runtime.register_consumer(&(sorter.clone() as Arc<dyn MemoryConsumer>));
+
+    while let Some(batch) = input.next().await {
+        let batch = batch?;
+        sorter.insert_batch(batch).await?;
     }
-}
 
-impl RecordBatchStream for SortStream {
-    fn schema(&self) -> SchemaRef {
-        self.schema.clone()
-    }
+    let result = sorter.sort().await;
+    runtime.drop_consumer(sorter.id());
+    result
 }
 
 #[cfg(test)]
 mod tests {
-    use std::collections::{BTreeMap, HashMap};
-
     use super::*;
     use crate::datasource::object_store::local::LocalFileSystem;
+    use crate::execution::runtime_env::RuntimeConfig;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::expressions::col;
     use crate::physical_plan::memory::MemoryExec;
@@ -325,17 +658,17 @@ mod tests {
         collect,
         file_format::{CsvExec, FileScanConfig},
     };
+    use crate::test;
     use crate::test::assert_is_pending;
-    use crate::test::exec::assert_strong_count_converges_to_zero;
-    use crate::test::{self, exec::BlockingExec};
+    use crate::test::exec::{assert_strong_count_converges_to_zero, 
BlockingExec};
     use crate::test_util;
     use arrow::array::*;
+    use arrow::compute::SortOptions;
     use arrow::datatypes::*;
     use futures::FutureExt;
+    use std::collections::{BTreeMap, HashMap};
 
-    #[tokio::test]
-    async fn test_sort() -> Result<()> {
-        let runtime = Arc::new(RuntimeEnv::default());
+    async fn sort_with_runtime(runtime: Arc<RuntimeEnv>) -> 
Result<Vec<RecordBatch>> {
         let schema = test_util::aggr_test_schema();
         let partitions = 4;
         let (_, files) =
@@ -376,7 +709,42 @@ mod tests {
             Arc::new(CoalescePartitionsExec::new(Arc::new(csv))),
         )?);
 
-        let result: Vec<RecordBatch> = collect(sort_exec, runtime).await?;
+        collect(sort_exec, runtime).await
+    }
+
+    #[tokio::test]
+    async fn test_in_mem_sort() -> Result<()> {
+        let runtime = Arc::new(RuntimeEnv::default());
+        let result = sort_with_runtime(runtime).await?;
+
+        assert_eq!(result.len(), 1);
+
+        let columns = result[0].columns();
+
+        let c1 = as_string_array(&columns[0]);
+        assert_eq!(c1.value(0), "a");
+        assert_eq!(c1.value(c1.len() - 1), "e");
+
+        let c2 = as_primitive_array::<UInt32Type>(&columns[1]);
+        assert_eq!(c2.value(0), 1);
+        assert_eq!(c2.value(c2.len() - 1), 5,);
+
+        let c7 = as_primitive_array::<UInt8Type>(&columns[6]);
+        assert_eq!(c7.value(0), 15);
+        assert_eq!(c7.value(c7.len() - 1), 254,);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sort_spill() -> Result<()> {
+        let config = RuntimeConfig::new()
+            .with_memory_fraction(1.0)
+            // trigger spill there will be 4 batches with 5.5KB for each
+            .with_max_execution_memory(12288);
+        let runtime = Arc::new(RuntimeEnv::new(config)?);
+        let result = sort_with_runtime(runtime).await?;
+
         assert_eq!(result.len(), 1);
 
         let columns = result[0].columns();
diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index 9f12891..189a9fb 100644
--- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -22,8 +22,7 @@ use crate::physical_plan::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
 };
 use std::any::Any;
-use std::cmp::Ordering;
-use std::collections::VecDeque;
+use std::collections::{BinaryHeap, VecDeque};
 use std::fmt::{Debug, Formatter};
 use std::pin::Pin;
 use std::sync::{Arc, Mutex};
@@ -168,18 +167,15 @@ impl ExecutionPlan for SortPreservingMergeExec {
                     })
                     .unzip();
 
-                Ok(Box::pin(
-                    SortPreservingMergeStream::new_from_receiver(
-                        receivers,
-                        AbortOnDropMany(join_handles),
-                        self.schema(),
-                        &self.expr,
-                        baseline_metrics,
-                        partition,
-                        runtime.clone(),
-                    )
-                    .await,
-                ))
+                Ok(Box::pin(SortPreservingMergeStream::new_from_receivers(
+                    receivers,
+                    AbortOnDropMany(join_handles),
+                    self.schema(),
+                    &self.expr,
+                    baseline_metrics,
+                    partition,
+                    runtime,
+                )))
             }
         }
     }
@@ -284,7 +280,7 @@ pub(crate) struct SortPreservingMergeStream {
     ///
     /// Exhausted cursors will be popped off the front once all
     /// their rows have been yielded to the output
-    cursors: Vec<VecDeque<SortKeyCursor>>,
+    cursors: Vec<VecDeque<Arc<SortKeyCursor>>>,
 
     /// The accumulated row indexes for the next record batch
     in_progress: Vec<RowIndex>,
@@ -304,6 +300,9 @@ pub(crate) struct SortPreservingMergeStream {
     /// An index to uniquely identify the input stream batch
     next_batch_index: usize,
 
+    /// min heap for record comparison
+    min_heap: BinaryHeap<Arc<SortKeyCursor>>,
+
     /// runtime
     runtime: Arc<RuntimeEnv>,
 }
@@ -316,7 +315,7 @@ impl Drop for SortPreservingMergeStream {
 
 impl SortPreservingMergeStream {
     #[allow(clippy::too_many_arguments)]
-    pub(crate) async fn new_from_receiver(
+    pub(crate) fn new_from_receivers(
         receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
         _drop_helper: AbortOnDropMany<()>,
         schema: SchemaRef,
@@ -325,16 +324,16 @@ impl SortPreservingMergeStream {
         partition: usize,
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
-        let cursors = (0..receivers.len())
+        let stream_count = receivers.len();
+        let cursors = (0..stream_count)
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
-
         let wrappers = 
receivers.into_iter().map(StreamWrapper::Receiver).collect();
         let streams = Arc::new(MergingStreams::new(partition, wrappers, 
runtime.clone()));
         runtime.register_consumer(&(streams.clone() as Arc<dyn 
MemoryConsumer>));
 
-        Self {
+        SortPreservingMergeStream {
             schema,
             cursors,
             streams,
@@ -345,11 +344,12 @@ impl SortPreservingMergeStream {
             aborted: false,
             in_progress: vec![],
             next_batch_index: 0,
+            min_heap: BinaryHeap::with_capacity(stream_count),
             runtime,
         }
     }
 
-    pub(crate) async fn new_from_stream(
+    pub(crate) fn new_from_streams(
         streams: Vec<SortedStream>,
         schema: SchemaRef,
         expressions: &[PhysicalSortExpr],
@@ -357,16 +357,15 @@ impl SortPreservingMergeStream {
         partition: usize,
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
-        let cursors = (0..streams.len())
+        let stream_count = streams.len();
+        let cursors = (0..stream_count)
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
-
         let wrappers = streams
             .into_iter()
             .map(|s| StreamWrapper::Stream(Some(s)))
-            .collect::<Vec<_>>();
-
+            .collect();
         let streams = Arc::new(MergingStreams::new(partition, wrappers, 
runtime.clone()));
         runtime.register_consumer(&(streams.clone() as Arc<dyn 
MemoryConsumer>));
 
@@ -381,6 +380,7 @@ impl SortPreservingMergeStream {
             aborted: false,
             in_progress: vec![],
             next_batch_index: 0,
+            min_heap: BinaryHeap::with_capacity(stream_count),
             runtime,
         }
     }
@@ -414,18 +414,24 @@ impl SortPreservingMergeStream {
                 return Poll::Ready(Err(e));
             }
             Some(Ok(batch)) => {
-                let cursor = match SortKeyCursor::new(
-                    self.next_batch_index, // assign this batch an ID
-                    Arc::new(batch),
-                    &self.column_expressions,
-                    self.sort_options.clone(),
-                ) {
-                    Ok(cursor) => cursor,
-                    Err(e) => {
-                        return 
Poll::Ready(Err(ArrowError::ExternalError(Box::new(e))));
-                    }
-                };
+                let cursor = Arc::new(
+                    match SortKeyCursor::new(
+                        idx,
+                        self.next_batch_index, // assign this batch an ID
+                        Arc::new(batch),
+                        &self.column_expressions,
+                        self.sort_options.clone(),
+                    ) {
+                        Ok(cursor) => cursor,
+                        Err(e) => {
+                            return Poll::Ready(Err(ArrowError::ExternalError(
+                                Box::new(e),
+                            )));
+                        }
+                    },
+                );
                 self.next_batch_index += 1;
+                self.min_heap.push(cursor.clone());
                 self.cursors[idx].push_back(cursor)
             }
         }
@@ -433,30 +439,6 @@ impl SortPreservingMergeStream {
         Poll::Ready(Ok(()))
     }
 
-    /// Returns the index of the next stream to pull a row from, or None
-    /// if all cursors for all streams are exhausted
-    fn next_stream_idx(&mut self) -> Result<Option<usize>> {
-        let mut min_cursor: Option<(usize, &mut SortKeyCursor)> = None;
-        for (idx, candidate) in self.cursors.iter_mut().enumerate() {
-            if let Some(candidate) = candidate.back_mut() {
-                if candidate.is_finished() {
-                    continue;
-                }
-
-                match min_cursor {
-                    None => min_cursor = Some((idx, candidate)),
-                    Some((_, ref mut min)) => {
-                        if min.compare(candidate)? == Ordering::Greater {
-                            min_cursor = Some((idx, candidate))
-                        }
-                    }
-                }
-            }
-        }
-
-        Ok(min_cursor.map(|(idx, _)| idx))
-    }
-
     /// Drains the in_progress row indexes, and builds a new RecordBatch from 
them
     ///
     /// Will then drop any cursors for which all rows have been yielded to the 
output
@@ -588,44 +570,44 @@ impl SortPreservingMergeStream {
             let elapsed_compute = 
self.baseline_metrics.elapsed_compute().clone();
             let _timer = elapsed_compute.timer();
 
-            let stream_idx = match self.next_stream_idx() {
-                Ok(Some(idx)) => idx,
-                Ok(None) if self.in_progress.is_empty() => return 
Poll::Ready(None),
-                Ok(None) => return 
Poll::Ready(Some(self.build_record_batch())),
-                Err(e) => {
-                    self.aborted = true;
-                    return 
Poll::Ready(Some(Err(ArrowError::ExternalError(Box::new(
-                        e,
-                    )))));
-                }
-            };
-
-            let cursors = &mut self.cursors[stream_idx];
-            let cursor_idx = cursors.len() - 1;
-            let cursor = cursors.back_mut().unwrap();
-            let row_idx = cursor.advance();
-            let cursor_finished = cursor.is_finished();
-
-            self.in_progress.push(RowIndex {
-                stream_idx,
-                cursor_idx,
-                row_idx,
-            });
+            match self.min_heap.pop() {
+                Some(cursor) => {
+                    let stream_idx = cursor.stream_idx;
+                    let cursor_idx = self.cursors[stream_idx].len() - 1;
+                    let row_idx = cursor.advance();
+
+                    let mut cursor_finished = false;
+                    // insert the cursor back to min_heap if the record batch 
is not exhausted
+                    if !cursor.is_finished() {
+                        self.min_heap.push(cursor);
+                    } else {
+                        cursor_finished = true;
+                    }
 
-            if self.in_progress.len() == self.runtime.batch_size() {
-                return Poll::Ready(Some(self.build_record_batch()));
-            }
+                    self.in_progress.push(RowIndex {
+                        stream_idx,
+                        cursor_idx,
+                        row_idx,
+                    });
 
-            // If removed the last row from the cursor, need to fetch a new 
record
-            // batch if possible, before looping round again
-            if cursor_finished {
-                match futures::ready!(self.maybe_poll_stream(cx, stream_idx)) {
-                    Ok(_) => {}
-                    Err(e) => {
-                        self.aborted = true;
-                        return Poll::Ready(Some(Err(e)));
+                    if self.in_progress.len() == self.runtime.batch_size() {
+                        return Poll::Ready(Some(self.build_record_batch()));
+                    }
+
+                    // If removed the last row from the cursor, need to fetch 
a new record
+                    // batch if possible, before looping round again
+                    if cursor_finished {
+                        match futures::ready!(self.maybe_poll_stream(cx, 
stream_idx)) {
+                            Ok(_) => {}
+                            Err(e) => {
+                                self.aborted = true;
+                                return Poll::Ready(Some(Err(e)));
+                            }
+                        }
                     }
                 }
+                None if self.in_progress.is_empty() => return 
Poll::Ready(None),
+                None => return Poll::Ready(Some(self.build_record_batch())),
             }
         }
     }
@@ -1089,8 +1071,6 @@ mod tests {
 
     #[tokio::test]
     async fn test_partition_sort_streaming_input_output() {
-        let runtime =
-            
Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(23)).unwrap());
         let schema = test_util::aggr_test_schema();
 
         let sort = vec![
@@ -1106,12 +1086,15 @@ mod tests {
             },
         ];
 
+        let runtime = Arc::new(RuntimeEnv::default());
         let input =
             sorted_partitioned_input(sort.clone(), &[10, 5, 13], 
runtime.clone()).await;
-        let basic = basic_sort(input.clone(), sort.clone(), 
runtime.clone()).await;
+        let basic = basic_sort(input.clone(), sort.clone(), runtime).await;
 
+        let runtime_bs_23 =
+            
Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(23)).unwrap());
         let merge = Arc::new(SortPreservingMergeExec::new(sort, input));
-        let merged = collect(merge, runtime.clone()).await.unwrap();
+        let merged = collect(merge, runtime_bs_23).await.unwrap();
 
         assert_eq!(merged.len(), 14);
 
@@ -1242,7 +1225,7 @@ mod tests {
         let metrics = ExecutionPlanMetricsSet::new();
         let baseline_metrics = BaselineMetrics::new(&metrics, 0);
 
-        let merge_stream = SortPreservingMergeStream::new_from_receiver(
+        let merge_stream = SortPreservingMergeStream::new_from_receivers(
             receivers,
             // Use empty vector since we want to use the join handles ourselves
             AbortOnDropMany(vec![]),
@@ -1251,8 +1234,7 @@ mod tests {
             baseline_metrics,
             0,
             runtime.clone(),
-        )
-        .await;
+        );
 
         let mut merged = 
common::collect(Box::pin(merge_stream)).await.unwrap();
 
diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs
index 5e14524..5a4f907 100644
--- a/datafusion/tests/provider_filter_pushdown.rs
+++ b/datafusion/tests/provider_filter_pushdown.rs
@@ -25,6 +25,7 @@ use datafusion::execution::context::ExecutionContext;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::logical_plan::Expr;
 use datafusion::physical_plan::common::SizedRecordBatchStream;
+use datafusion::physical_plan::metrics::{BaselineMetrics, 
ExecutionPlanMetricsSet};
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, 
Statistics,
 };
@@ -81,12 +82,15 @@ impl ExecutionPlan for CustomPlan {
 
     async fn execute(
         &self,
-        _partition: usize,
+        partition: usize,
         _runtime: Arc<RuntimeEnv>,
     ) -> Result<SendableRecordBatchStream> {
+        let metrics = ExecutionPlanMetricsSet::new();
+        let baseline_metrics = BaselineMetrics::new(&metrics, partition);
         Ok(Box::pin(SizedRecordBatchStream::new(
             self.schema(),
             self.batches.clone(),
+            baseline_metrics,
         )))
     }
 
diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs
index 1aa4eb6..85b59e6 100644
--- a/datafusion/tests/sql/joins.rs
+++ b/datafusion/tests/sql/joins.rs
@@ -419,32 +419,32 @@ async fn cross_join_unbalanced() {
 
     // the order of the values is not determinisitic, so we need to sort to 
check the values
     let sql =
-        "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, 
t1_name";
+        "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, 
t1_name, t2_name";
     let actual = execute_to_batches(&mut ctx, sql).await;
     let expected = vec![
         "+-------+---------+---------+",
         "| t1_id | t1_name | t2_name |",
         "+-------+---------+---------+",
-        "| 11    | a       | z       |",
-        "| 11    | a       | y       |",
-        "| 11    | a       | x       |",
         "| 11    | a       | w       |",
-        "| 22    | b       | z       |",
-        "| 22    | b       | y       |",
-        "| 22    | b       | x       |",
+        "| 11    | a       | x       |",
+        "| 11    | a       | y       |",
+        "| 11    | a       | z       |",
         "| 22    | b       | w       |",
-        "| 33    | c       | z       |",
-        "| 33    | c       | y       |",
-        "| 33    | c       | x       |",
+        "| 22    | b       | x       |",
+        "| 22    | b       | y       |",
+        "| 22    | b       | z       |",
         "| 33    | c       | w       |",
-        "| 44    | d       | z       |",
-        "| 44    | d       | y       |",
-        "| 44    | d       | x       |",
+        "| 33    | c       | x       |",
+        "| 33    | c       | y       |",
+        "| 33    | c       | z       |",
         "| 44    | d       | w       |",
-        "| 77    | e       | z       |",
-        "| 77    | e       | y       |",
-        "| 77    | e       | x       |",
+        "| 44    | d       | x       |",
+        "| 44    | d       | y       |",
+        "| 44    | d       | z       |",
         "| 77    | e       | w       |",
+        "| 77    | e       | x       |",
+        "| 77    | e       | y       |",
+        "| 77    | e       | z       |",
         "+-------+---------+---------+",
     ];
     assert_batches_eq!(expected, &actual);

Reply via email to