This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 641338f  Add `MemTrackingMetrics` to ease memory tracking for non-limited memory consumers (#1691)
641338f is described below

commit 641338f726549c10c5bafee34537dc1e56cdec04
Author: Yijie Shen <[email protected]>
AuthorDate: Sun Jan 30 00:02:49 2022 +0800

    Add `MemTrackingMetrics` to ease memory tracking for non-limited memory consumers (#1691)
    
    * Memory manager no longer tracks consumers, update AggregatedMetricsSet
    
    * Easy memory tracking with metrics
    
    * use tracking metrics in SPMS
    
    * tests
    
    * fix
    
    * doc
    
    * Update datafusion/src/physical_plan/sorts/sort.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    * make tracker AtomicUsize
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/src/execution/memory_manager.rs         | 134 +++++++-------
 datafusion/src/execution/runtime_env.rs            |  24 ++-
 datafusion/src/physical_plan/common.rs             |  12 +-
 datafusion/src/physical_plan/explain.rs            |   6 +-
 datafusion/src/physical_plan/metrics/aggregated.rs | 155 ----------------
 datafusion/src/physical_plan/metrics/baseline.rs   |  14 +-
 datafusion/src/physical_plan/metrics/composite.rs  | 205 +++++++++++++++++++++
 datafusion/src/physical_plan/metrics/mod.rs        |   6 +-
 datafusion/src/physical_plan/metrics/tracker.rs    | 131 +++++++++++++
 datafusion/src/physical_plan/sorts/mod.rs          |   9 -
 datafusion/src/physical_plan/sorts/sort.rs         |  95 +++++-----
 .../physical_plan/sorts/sort_preserving_merge.rs   | 114 +++---------
 datafusion/tests/provider_filter_pushdown.rs       |   6 +-
 13 files changed, 525 insertions(+), 386 deletions(-)
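
A note on the API change: requesting consumers (those that can spill) still implement `MemoryConsumer` but are now registered with the memory manager by id only, while tracking-only ("non-limited") consumers no longer implement the trait at all and report through the new `MemTrackingMetrics`. A minimal sketch of the two paths, where `runtime`, `sorter`, `metrics_set`, `partition` and `size` stand in for an operator's own state:

    // requesting consumer: register by id; its memory is reclaimed when it is dropped
    runtime.register_requester(sorter.id());

    // tracking-only consumer: no MemoryConsumer impl needed any more
    let tracking = MemTrackingMetrics::new_with_rt(&metrics_set, partition, runtime.clone());
    tracking.init_mem_used(size);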

diff --git a/datafusion/src/execution/memory_manager.rs b/datafusion/src/execution/memory_manager.rs
index 53eb720..0fb3cfb 100644
--- a/datafusion/src/execution/memory_manager.rs
+++ b/datafusion/src/execution/memory_manager.rs
@@ -19,12 +19,12 @@
 
 use crate::error::{DataFusionError, Result};
 use async_trait::async_trait;
-use hashbrown::HashMap;
+use hashbrown::HashSet;
 use log::debug;
 use std::fmt;
 use std::fmt::{Debug, Display, Formatter};
 use std::sync::atomic::{AtomicUsize, Ordering};
-use std::sync::{Arc, Condvar, Mutex, Weak};
+use std::sync::{Arc, Condvar, Mutex};
 
 static CONSUMER_ID: AtomicUsize = AtomicUsize::new(0);
 
@@ -245,10 +245,10 @@ The memory management architecture is the following:
 /// Manage memory usage during physical plan execution
 #[derive(Debug)]
 pub struct MemoryManager {
-    requesters: Arc<Mutex<HashMap<MemoryConsumerId, Weak<dyn MemoryConsumer>>>>,
-    trackers: Arc<Mutex<HashMap<MemoryConsumerId, Weak<dyn MemoryConsumer>>>>,
+    requesters: Arc<Mutex<HashSet<MemoryConsumerId>>>,
     pool_size: usize,
     requesters_total: Arc<Mutex<usize>>,
+    trackers_total: AtomicUsize,
     cv: Condvar,
 }
 
@@ -267,10 +267,10 @@ impl MemoryManager {
                 );
 
                 Arc::new(Self {
-                    requesters: Arc::new(Mutex::new(HashMap::new())),
-                    trackers: Arc::new(Mutex::new(HashMap::new())),
+                    requesters: Arc::new(Mutex::new(HashSet::new())),
                     pool_size,
                     requesters_total: Arc::new(Mutex::new(0)),
+                    trackers_total: AtomicUsize::new(0),
                     cv: Condvar::new(),
                 })
             }
@@ -278,30 +278,36 @@ impl MemoryManager {
     }
 
     fn get_tracker_total(&self) -> usize {
-        let trackers = self.trackers.lock().unwrap();
-        if trackers.len() > 0 {
-            trackers.values().fold(0usize, |acc, y| match y.upgrade() {
-                None => acc,
-                Some(t) => acc + t.mem_used(),
-            })
-        } else {
-            0
-        }
+        self.trackers_total.load(Ordering::SeqCst)
     }
 
-    /// Register a new memory consumer for memory usage tracking
-    pub(crate) fn register_consumer(&self, consumer: &Arc<dyn MemoryConsumer>) {
-        let id = consumer.id().clone();
-        match consumer.type_() {
-            ConsumerType::Requesting => {
-                let mut requesters = self.requesters.lock().unwrap();
-                requesters.insert(id, Arc::downgrade(consumer));
-            }
-            ConsumerType::Tracking => {
-                let mut trackers = self.trackers.lock().unwrap();
-                trackers.insert(id, Arc::downgrade(consumer));
-            }
-        }
+    pub(crate) fn grow_tracker_usage(&self, delta: usize) {
+        self.trackers_total.fetch_add(delta, Ordering::SeqCst);
+    }
+
+    pub(crate) fn shrink_tracker_usage(&self, delta: usize) {
+        let update =
+            self.trackers_total
+                .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| {
+                    if x >= delta {
+                        Some(x - delta)
+                    } else {
+                        None
+                    }
+                });
+        update.expect(&*format!(
+            "Tracker total memory shrink by {} underflow, current value is ",
+            delta
+        ));
+    }
+
+    fn get_requester_total(&self) -> usize {
+        *self.requesters_total.lock().unwrap()
+    }
+
+    /// Register a new memory requester
+    pub(crate) fn register_requester(&self, requester_id: &MemoryConsumerId) {
+        self.requesters.lock().unwrap().insert(requester_id.clone());
     }
 
     fn max_mem_for_requesters(&self) -> usize {
@@ -317,7 +323,6 @@ impl MemoryManager {
 
         let granted;
         loop {
-            let remaining = rqt_max - *rqt_current_used;
             let max_per_rqt = rqt_max / num_rqt;
             let min_per_rqt = max_per_rqt / 2;
 
@@ -326,6 +331,7 @@ impl MemoryManager {
                 break;
             }
 
+            let remaining = rqt_max.checked_sub(*rqt_current_used).unwrap_or_default();
             if remaining >= required {
                 granted = true;
                 *rqt_current_used += required;
@@ -347,46 +353,37 @@ impl MemoryManager {
 
     fn record_free_then_acquire(&self, freed: usize, acquired: usize) {
         let mut requesters_total = self.requesters_total.lock().unwrap();
+        assert!(*requesters_total >= freed);
         *requesters_total -= freed;
         *requesters_total += acquired;
         self.cv.notify_all()
     }
 
-    /// Drop a memory consumer from memory usage tracking
-    pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId) {
+    /// Drop a memory consumer and reclaim the memory
+    pub(crate) fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) {
         // find in requesters first
         {
             let mut requesters = self.requesters.lock().unwrap();
-            if requesters.remove(id).is_some() {
-                return;
+            if requesters.remove(id) {
+                let mut total = self.requesters_total.lock().unwrap();
+                assert!(*total >= mem_used);
+                *total -= mem_used;
             }
         }
-        let mut trackers = self.trackers.lock().unwrap();
-        trackers.remove(id);
+        self.shrink_tracker_usage(mem_used);
+        self.cv.notify_all();
     }
 }
 
 impl Display for MemoryManager {
     fn fmt(&self, f: &mut Formatter) -> fmt::Result {
-        let requesters =
-            self.requesters
-                .lock()
-                .unwrap()
-                .values()
-                .fold(vec![], |mut acc, consumer| match consumer.upgrade() {
-                    None => acc,
-                    Some(c) => {
-                        acc.push(format!("{}", c));
-                        acc
-                    }
-                });
-        let tracker_mem = self.get_tracker_total();
         write!(f,
-               "MemoryManager usage statistics: total {}, tracker used {}, 
total {} requesters detail: \n {},",
-                human_readable_size(self.pool_size),
-                human_readable_size(tracker_mem),
-                &requesters.len(),
-               requesters.join("\n"))
+               "MemoryManager usage statistics: total {}, trackers used {}, 
total {} requesters used: {}",
+               human_readable_size(self.pool_size),
+               human_readable_size(self.get_tracker_total()),
+               self.requesters.lock().unwrap().len(),
+               human_readable_size(self.get_requester_total()),
+        )
     }
 }
 
@@ -418,6 +415,8 @@ mod tests {
     use super::*;
     use crate::error::Result;
     use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+    use crate::execution::MemoryConsumer;
+    use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics};
     use async_trait::async_trait;
     use std::sync::atomic::{AtomicUsize, Ordering};
     use std::sync::Arc;
@@ -487,6 +486,7 @@ mod tests {
 
     impl DummyTracker {
         fn new(partition: usize, runtime: Arc<RuntimeEnv>, mem_used: usize) -> Self {
+            runtime.grow_tracker_usage(mem_used);
             Self {
                 id: MemoryConsumerId::new(partition),
                 runtime,
@@ -528,23 +528,29 @@ mod tests {
             .with_memory_manager(MemoryManagerConfig::try_new_limit(100, 1.0).unwrap());
         let runtime = Arc::new(RuntimeEnv::new(config).unwrap());
 
-        let tracker1 = Arc::new(DummyTracker::new(0, runtime.clone(), 5));
-        runtime.register_consumer(&(tracker1.clone() as Arc<dyn MemoryConsumer>));
+        DummyTracker::new(0, runtime.clone(), 5);
         assert_eq!(runtime.memory_manager.get_tracker_total(), 5);
 
-        let tracker2 = Arc::new(DummyTracker::new(0, runtime.clone(), 10));
-        runtime.register_consumer(&(tracker2.clone() as Arc<dyn MemoryConsumer>));
+        let tracker1 = DummyTracker::new(0, runtime.clone(), 10);
         assert_eq!(runtime.memory_manager.get_tracker_total(), 15);
 
-        let tracker3 = Arc::new(DummyTracker::new(0, runtime.clone(), 15));
-        runtime.register_consumer(&(tracker3.clone() as Arc<dyn MemoryConsumer>));
+        DummyTracker::new(0, runtime.clone(), 15);
         assert_eq!(runtime.memory_manager.get_tracker_total(), 30);
 
-        runtime.drop_consumer(tracker2.id());
+        runtime.drop_consumer(tracker1.id(), tracker1.mem_used);
+        assert_eq!(runtime.memory_manager.get_tracker_total(), 20);
+
+        // MemTrackingMetrics as an easy way to track memory
+        let ms = ExecutionPlanMetricsSet::new();
+        let tracking_metric = MemTrackingMetrics::new_with_rt(&ms, 0, runtime.clone());
+        tracking_metric.init_mem_used(15);
+        assert_eq!(runtime.memory_manager.get_tracker_total(), 35);
+
+        drop(tracking_metric);
         assert_eq!(runtime.memory_manager.get_tracker_total(), 20);
 
-        let requester1 = Arc::new(DummyRequester::new(0, runtime.clone()));
-        runtime.register_consumer(&(requester1.clone() as Arc<dyn MemoryConsumer>));
+        let requester1 = DummyRequester::new(0, runtime.clone());
+        runtime.register_requester(requester1.id());
 
         // first requester entered, should be able to use any of the remaining 80
         requester1.do_with_mem(40).await.unwrap();
@@ -553,8 +559,8 @@ mod tests {
         assert_eq!(requester1.mem_used(), 50);
         assert_eq!(*runtime.memory_manager.requesters_total.lock().unwrap(), 50);
 
-        let requester2 = Arc::new(DummyRequester::new(0, runtime.clone()));
-        runtime.register_consumer(&(requester2.clone() as Arc<dyn MemoryConsumer>));
+        let requester2 = DummyRequester::new(0, runtime.clone());
+        runtime.register_requester(requester2.id());
 
         requester2.do_with_mem(20).await.unwrap();
         requester2.do_with_mem(30).await.unwrap();
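
The manager itself no longer holds weak references to consumers: requesters are kept as a HashSet of ids and the combined usage of all trackers lives in a single AtomicUsize. A sketch mirroring the updated test above (the 100-byte limit and the deltas are illustrative):

    let config = RuntimeConfig::new()
        .with_memory_manager(MemoryManagerConfig::try_new_limit(100, 1.0).unwrap());
    let runtime = Arc::new(RuntimeEnv::new(config).unwrap());

    // tracking-only consumers just move the shared counter up and down
    runtime.grow_tracker_usage(30);
    runtime.shrink_tracker_usage(10); // panics if it would underflow

    // requesters are registered by id; dropping one returns its reported usage
    let id = MemoryConsumerId::new(0);
    runtime.register_requester(&id);
    runtime.drop_consumer(&id, 0);
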
diff --git a/datafusion/src/execution/runtime_env.rs b/datafusion/src/execution/runtime_env.rs
index cdcd1f7..e993b38 100644
--- a/datafusion/src/execution/runtime_env.rs
+++ b/datafusion/src/execution/runtime_env.rs
@@ -22,9 +22,7 @@ use crate::{
     error::Result,
     execution::{
         disk_manager::{DiskManager, DiskManagerConfig},
-        memory_manager::{
-            MemoryConsumer, MemoryConsumerId, MemoryManager, MemoryManagerConfig,
-        },
+        memory_manager::{MemoryConsumerId, MemoryManager, MemoryManagerConfig},
     },
 };
 
@@ -71,13 +69,23 @@ impl RuntimeEnv {
     }
 
     /// Register the consumer to get it tracked
-    pub fn register_consumer(&self, memory_consumer: &Arc<dyn MemoryConsumer>) {
-        self.memory_manager.register_consumer(memory_consumer);
+    pub fn register_requester(&self, id: &MemoryConsumerId) {
+        self.memory_manager.register_requester(id);
     }
 
-    /// Drop the consumer from get tracked
-    pub fn drop_consumer(&self, id: &MemoryConsumerId) {
-        self.memory_manager.drop_consumer(id)
+    /// Drop the consumer from get tracked, reclaim memory
+    pub fn drop_consumer(&self, id: &MemoryConsumerId, mem_used: usize) {
+        self.memory_manager.drop_consumer(id, mem_used)
+    }
+
+    /// Grow tracker memory of `delta`
+    pub fn grow_tracker_usage(&self, delta: usize) {
+        self.memory_manager.grow_tracker_usage(delta)
+    }
+
+    /// Shrink tracker memory of `delta`
+    pub fn shrink_tracker_usage(&self, delta: usize) {
+        self.memory_manager.shrink_tracker_usage(delta)
     }
 }
 
diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs
index 390f004..bc4400d 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -20,7 +20,7 @@
 use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::error::{DataFusionError, Result};
 use crate::execution::runtime_env::RuntimeEnv;
-use crate::physical_plan::metrics::BaselineMetrics;
+use crate::physical_plan::metrics::MemTrackingMetrics;
 use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics};
 use arrow::compute::concat;
 use arrow::datatypes::{Schema, SchemaRef};
@@ -43,7 +43,7 @@ pub struct SizedRecordBatchStream {
     schema: SchemaRef,
     batches: Vec<Arc<RecordBatch>>,
     index: usize,
-    baseline_metrics: BaselineMetrics,
+    metrics: MemTrackingMetrics,
 }
 
 impl SizedRecordBatchStream {
@@ -51,13 +51,15 @@ impl SizedRecordBatchStream {
     pub fn new(
         schema: SchemaRef,
         batches: Vec<Arc<RecordBatch>>,
-        baseline_metrics: BaselineMetrics,
+        metrics: MemTrackingMetrics,
     ) -> Self {
+        let size = batches.iter().map(|b| batch_byte_size(b)).sum::<usize>();
+        metrics.init_mem_used(size);
         SizedRecordBatchStream {
             schema,
             index: 0,
             batches,
-            baseline_metrics,
+            metrics,
         }
     }
 }
@@ -75,7 +77,7 @@ impl Stream for SizedRecordBatchStream {
         } else {
             None
         });
-        self.baseline_metrics.record_poll(poll)
+        self.metrics.record_poll(poll)
     }
 }
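
`SizedRecordBatchStream` now measures its batches up front and reports the total through the tracking metrics, so callers only hand it a `MemTrackingMetrics`. A rough sketch of constructing one (the schema and array contents are placeholders):

    use arrow::array::Int32Array;
    use arrow::datatypes::{DataType, Field, Schema};
    use arrow::record_batch::RecordBatch;

    let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();

    let metrics = ExecutionPlanMetricsSet::new();
    let tracking_metrics = MemTrackingMetrics::new(&metrics, 0);
    // new() sums batch_byte_size over the batches and calls init_mem_used()
    let stream = SizedRecordBatchStream::new(schema, vec![Arc::new(batch)], tracking_metrics);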
 
diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs
index f827dc3..eb18926 100644
--- a/datafusion/src/physical_plan/explain.rs
+++ b/datafusion/src/physical_plan/explain.rs
@@ -32,7 +32,7 @@ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatc
 
 use super::SendableRecordBatchStream;
 use crate::execution::runtime_env::RuntimeEnv;
-use crate::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
+use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics};
 use async_trait::async_trait;
 
 /// Explain execution plan operator. This operator contains the string
@@ -148,12 +148,12 @@ impl ExecutionPlan for ExplainExec {
         )?;
 
         let metrics = ExecutionPlanMetricsSet::new();
-        let baseline_metrics = BaselineMetrics::new(&metrics, partition);
+        let tracking_metrics = MemTrackingMetrics::new(&metrics, partition);
 
         Ok(Box::pin(SizedRecordBatchStream::new(
             self.schema.clone(),
             vec![Arc::new(record_batch)],
-            baseline_metrics,
+            tracking_metrics,
         )))
     }
 
diff --git a/datafusion/src/physical_plan/metrics/aggregated.rs b/datafusion/src/physical_plan/metrics/aggregated.rs
deleted file mode 100644
index c55cc16..0000000
--- a/datafusion/src/physical_plan/metrics/aggregated.rs
+++ /dev/null
@@ -1,155 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! Metrics common for complex operators with multiple steps.
-
-use crate::physical_plan::metrics::{
-    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricsSet, Time,
-};
-use std::sync::Arc;
-use std::time::Duration;
-
-#[derive(Debug, Clone)]
-/// Aggregates all metrics during a complex operation, which is composed of multiple steps and
-/// each stage reports its statistics separately.
-/// Give sort as an example, when the dataset is more significant than available memory, it will report
-/// multiple in-mem sort metrics and final merge-sort  metrics from `SortPreservingMergeStream`.
-/// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation),
-/// and which are intermediate metrics that we only account for elapsed_compute time.
-pub struct AggregatedMetricsSet {
-    intermediate: Arc<std::sync::Mutex<Vec<ExecutionPlanMetricsSet>>>,
-    final_: Arc<std::sync::Mutex<Vec<ExecutionPlanMetricsSet>>>,
-}
-
-impl Default for AggregatedMetricsSet {
-    fn default() -> Self {
-        Self::new()
-    }
-}
-
-impl AggregatedMetricsSet {
-    /// Create a new aggregated set
-    pub fn new() -> Self {
-        Self {
-            intermediate: Arc::new(std::sync::Mutex::new(vec![])),
-            final_: Arc::new(std::sync::Mutex::new(vec![])),
-        }
-    }
-
-    /// create a new intermediate baseline
-    pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics {
-        let ms = ExecutionPlanMetricsSet::new();
-        let result = BaselineMetrics::new(&ms, partition);
-        self.intermediate.lock().unwrap().push(ms);
-        result
-    }
-
-    /// create a new final baseline
-    pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics {
-        let ms = ExecutionPlanMetricsSet::new();
-        let result = BaselineMetrics::new(&ms, partition);
-        self.final_.lock().unwrap().push(ms);
-        result
-    }
-
-    fn merge_compute_time(&self, dest: &Time) {
-        let time1 = self
-            .intermediate
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| {
-                es.clone_inner()
-                    .elapsed_compute()
-                    .map_or(0u64, |v| v as u64)
-            })
-            .sum();
-        let time2 = self
-            .final_
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| {
-                es.clone_inner()
-                    .elapsed_compute()
-                    .map_or(0u64, |v| v as u64)
-            })
-            .sum();
-        dest.add_duration(Duration::from_nanos(time1));
-        dest.add_duration(Duration::from_nanos(time2));
-    }
-
-    fn merge_spill_count(&self, dest: &Count) {
-        let count1 = self
-            .intermediate
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| es.clone_inner().spill_count().map_or(0, |v| v))
-            .sum();
-        let count2 = self
-            .final_
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| es.clone_inner().spill_count().map_or(0, |v| v))
-            .sum();
-        dest.add(count1);
-        dest.add(count2);
-    }
-
-    fn merge_spilled_bytes(&self, dest: &Count) {
-        let count1 = self
-            .intermediate
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v))
-            .sum();
-        let count2 = self
-            .final_
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| es.clone_inner().spilled_bytes().map_or(0, |v| v))
-            .sum();
-        dest.add(count1);
-        dest.add(count2);
-    }
-
-    fn merge_output_count(&self, dest: &Count) {
-        let count = self
-            .final_
-            .lock()
-            .unwrap()
-            .iter()
-            .map(|es| es.clone_inner().output_rows().map_or(0, |v| v))
-            .sum();
-        dest.add(count);
-    }
-
-    /// Aggregate all metrics into a one
-    pub fn aggregate_all(&self) -> MetricsSet {
-        let metrics = ExecutionPlanMetricsSet::new();
-        let baseline = BaselineMetrics::new(&metrics, 0);
-        self.merge_compute_time(baseline.elapsed_compute());
-        self.merge_spill_count(baseline.spill_count());
-        self.merge_spilled_bytes(baseline.spilled_bytes());
-        self.merge_output_count(baseline.output_rows());
-        metrics.clone_inner()
-    }
-}
diff --git a/datafusion/src/physical_plan/metrics/baseline.rs b/datafusion/src/physical_plan/metrics/baseline.rs
index 50c49ec..8dff5ee 100644
--- a/datafusion/src/physical_plan/metrics/baseline.rs
+++ b/datafusion/src/physical_plan/metrics/baseline.rs
@@ -113,7 +113,7 @@ impl BaselineMetrics {
     /// Records the fact that this operator's execution is complete
     /// (recording the `end_time` metric).
     ///
-    /// Note care should be taken to call `done()` maually if
+    /// Note care should be taken to call `done()` manually if
     /// `BaselineMetrics` is not `drop`ped immediately upon operator
     /// completion, as async streams may not be dropped immediately
     /// depending on the consumer.
@@ -129,6 +129,13 @@ impl BaselineMetrics {
         self.output_rows.add(num_rows);
     }
 
+    /// If not previously recorded `done()`, record
+    pub fn try_done(&self) {
+        if self.end_time.value().is_none() {
+            self.end_time.record()
+        }
+    }
+
     /// Process a poll result of a stream producing output for an
     /// operator, recording the output rows and stream done time and
     /// returning the same poll result
@@ -151,10 +158,7 @@ impl BaselineMetrics {
 
 impl Drop for BaselineMetrics {
     fn drop(&mut self) {
-        // if not previously recorded, record
-        if self.end_time.value().is_none() {
-            self.end_time.record()
-        }
+        self.try_done()
     }
 }
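
`try_done` makes end-time recording idempotent, so the `Drop` impl can call it even when an operator already called `done()` explicitly. For example (sketch):

    let metrics_set = ExecutionPlanMetricsSet::new();
    let baseline = BaselineMetrics::new(&metrics_set, 0);
    baseline.done();     // records end_time
    baseline.try_done(); // no-op: end_time is already set
    drop(baseline);      // Drop also goes through try_done(), keeping the first timestamp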
 
diff --git a/datafusion/src/physical_plan/metrics/composite.rs b/datafusion/src/physical_plan/metrics/composite.rs
new file mode 100644
index 0000000..cd4d5c3
--- /dev/null
+++ b/datafusion/src/physical_plan/metrics/composite.rs
@@ -0,0 +1,205 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Metrics common for complex operators with multiple steps.
+
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::physical_plan::metrics::tracker::MemTrackingMetrics;
+use crate::physical_plan::metrics::{
+    BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricValue, MetricsSet, Time,
+    Timestamp,
+};
+use crate::physical_plan::Metric;
+use chrono::{TimeZone, Utc};
+use std::sync::Arc;
+use std::time::Duration;
+
+#[derive(Debug, Clone)]
+/// Collects all metrics during a complex operation, which is composed of multiple steps and
+/// each stage reports its statistics separately.
+/// Give sort as an example, when the dataset is more significant than available memory, it will report
+/// multiple in-mem sort metrics and final merge-sort  metrics from `SortPreservingMergeStream`.
+/// Therefore, We need a separation of metrics for which are final metrics (for output_rows accumulation),
+/// and which are intermediate metrics that we only account for elapsed_compute time.
+pub struct CompositeMetricsSet {
+    mid: ExecutionPlanMetricsSet,
+    final_: ExecutionPlanMetricsSet,
+}
+
+impl Default for CompositeMetricsSet {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl CompositeMetricsSet {
+    /// Create a new aggregated set
+    pub fn new() -> Self {
+        Self {
+            mid: ExecutionPlanMetricsSet::new(),
+            final_: ExecutionPlanMetricsSet::new(),
+        }
+    }
+
+    /// create a new intermediate baseline
+    pub fn new_intermediate_baseline(&self, partition: usize) -> BaselineMetrics {
+        BaselineMetrics::new(&self.mid, partition)
+    }
+
+    /// create a new final baseline
+    pub fn new_final_baseline(&self, partition: usize) -> BaselineMetrics {
+        BaselineMetrics::new(&self.final_, partition)
+    }
+
+    /// create a new intermediate memory tracking metrics
+    pub fn new_intermediate_tracking(
+        &self,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> MemTrackingMetrics {
+        MemTrackingMetrics::new_with_rt(&self.mid, partition, runtime)
+    }
+
+    /// create a new final memory tracking metrics
+    pub fn new_final_tracking(
+        &self,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> MemTrackingMetrics {
+        MemTrackingMetrics::new_with_rt(&self.final_, partition, runtime)
+    }
+
+    fn merge_compute_time(&self, dest: &Time) {
+        let time1 = self
+            .mid
+            .clone_inner()
+            .elapsed_compute()
+            .map_or(0u64, |v| v as u64);
+        let time2 = self
+            .final_
+            .clone_inner()
+            .elapsed_compute()
+            .map_or(0u64, |v| v as u64);
+        dest.add_duration(Duration::from_nanos(time1));
+        dest.add_duration(Duration::from_nanos(time2));
+    }
+
+    fn merge_spill_count(&self, dest: &Count) {
+        let count1 = self.mid.clone_inner().spill_count().map_or(0, |v| v);
+        let count2 = self.final_.clone_inner().spill_count().map_or(0, |v| v);
+        dest.add(count1);
+        dest.add(count2);
+    }
+
+    fn merge_spilled_bytes(&self, dest: &Count) {
+        let count1 = self.mid.clone_inner().spilled_bytes().map_or(0, |v| v);
+        let count2 = self.final_.clone_inner().spill_count().map_or(0, |v| v);
+        dest.add(count1);
+        dest.add(count2);
+    }
+
+    fn merge_output_count(&self, dest: &Count) {
+        let count = self.final_.clone_inner().output_rows().map_or(0, |v| v);
+        dest.add(count);
+    }
+
+    fn merge_start_time(&self, dest: &Timestamp) {
+        let start1 = self
+            .mid
+            .clone_inner()
+            .sum(|metric| matches!(metric.value(), MetricValue::StartTimestamp(_)))
+            .map(|v| v.as_usize());
+        let start2 = self
+            .final_
+            .clone_inner()
+            .sum(|metric| matches!(metric.value(), MetricValue::StartTimestamp(_)))
+            .map(|v| v.as_usize());
+        match (start1, start2) {
+            (Some(start1), Some(start2)) => {
+                dest.set(Utc.timestamp_nanos(start1.min(start2) as i64))
+            }
+            (Some(start1), None) => dest.set(Utc.timestamp_nanos(start1 as i64)),
+            (None, Some(start2)) => dest.set(Utc.timestamp_nanos(start2 as i64)),
+            (None, None) => {}
+        }
+    }
+
+    fn merge_end_time(&self, dest: &Timestamp) {
+        let start1 = self
+            .mid
+            .clone_inner()
+            .sum(|metric| matches!(metric.value(), MetricValue::EndTimestamp(_)))
+            .map(|v| v.as_usize());
+        let start2 = self
+            .final_
+            .clone_inner()
+            .sum(|metric| matches!(metric.value(), MetricValue::EndTimestamp(_)))
+            .map(|v| v.as_usize());
+        match (start1, start2) {
+            (Some(start1), Some(start2)) => {
+                dest.set(Utc.timestamp_nanos(start1.max(start2) as i64))
+            }
+            (Some(start1), None) => dest.set(Utc.timestamp_nanos(start1 as i64)),
+            (None, Some(start2)) => dest.set(Utc.timestamp_nanos(start2 as i64)),
+            (None, None) => {}
+        }
+    }
+
+    /// Aggregate all metrics into a one
+    pub fn aggregate_all(&self) -> MetricsSet {
+        let mut metrics = MetricsSet::new();
+        let elapsed_time = Time::new();
+        let spill_count = Count::new();
+        let spilled_bytes = Count::new();
+        let output_count = Count::new();
+        let start_time = Timestamp::new();
+        let end_time = Timestamp::new();
+
+        metrics.push(Arc::new(Metric::new(
+            MetricValue::ElapsedCompute(elapsed_time.clone()),
+            None,
+        )));
+        metrics.push(Arc::new(Metric::new(
+            MetricValue::SpillCount(spill_count.clone()),
+            None,
+        )));
+        metrics.push(Arc::new(Metric::new(
+            MetricValue::SpilledBytes(spilled_bytes.clone()),
+            None,
+        )));
+        metrics.push(Arc::new(Metric::new(
+            MetricValue::OutputRows(output_count.clone()),
+            None,
+        )));
+        metrics.push(Arc::new(Metric::new(
+            MetricValue::StartTimestamp(start_time.clone()),
+            None,
+        )));
+        metrics.push(Arc::new(Metric::new(
+            MetricValue::EndTimestamp(end_time.clone()),
+            None,
+        )));
+
+        self.merge_compute_time(&elapsed_time);
+        self.merge_spill_count(&spill_count);
+        self.merge_spilled_bytes(&spilled_bytes);
+        self.merge_output_count(&output_count);
+        self.merge_start_time(&start_time);
+        self.merge_end_time(&end_time);
+        metrics
+    }
+}
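
`CompositeMetricsSet` keeps one intermediate and one final `ExecutionPlanMetricsSet` and hands out baseline or tracking metrics that write into them, instead of accumulating a vector of sets per stage as `AggregatedMetricsSet` did. A sketch of how a multi-stage operator such as sort is expected to use it (`partition` and `runtime` are placeholders):

    let metrics_set = CompositeMetricsSet::new();

    // each in-memory sort / spill pass records into the intermediate set
    let spill_metrics = metrics_set.new_intermediate_tracking(partition, runtime.clone());

    // the stream handed back to the caller records into the final set
    let output_metrics = metrics_set.new_final_tracking(partition, runtime.clone());

    // ExecutionPlan::metrics() then merges compute time, spills, output rows and timestamps
    let merged = metrics_set.aggregate_all();
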
diff --git a/datafusion/src/physical_plan/metrics/mod.rs b/datafusion/src/physical_plan/metrics/mod.rs
index d489599..e609beb 100644
--- a/datafusion/src/physical_plan/metrics/mod.rs
+++ b/datafusion/src/physical_plan/metrics/mod.rs
@@ -17,9 +17,10 @@
 
 //! Metrics for recording information about execution
 
-mod aggregated;
 mod baseline;
 mod builder;
+mod composite;
+mod tracker;
 mod value;
 
 use std::{
@@ -31,9 +32,10 @@ use std::{
 use hashbrown::HashMap;
 
 // public exports
-pub use aggregated::AggregatedMetricsSet;
 pub use baseline::{BaselineMetrics, RecordOutput};
 pub use builder::MetricBuilder;
+pub use composite::CompositeMetricsSet;
+pub use tracker::MemTrackingMetrics;
 pub use value::{Count, Gauge, MetricValue, ScopedTimerGuard, Time, Timestamp};
 
 /// Something that tracks a value of interest (metric) of a DataFusion
diff --git a/datafusion/src/physical_plan/metrics/tracker.rs b/datafusion/src/physical_plan/metrics/tracker.rs
new file mode 100644
index 0000000..bdceadb
--- /dev/null
+++ b/datafusion/src/physical_plan/metrics/tracker.rs
@@ -0,0 +1,131 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Metrics with memory usage tracking capability
+
+use crate::execution::runtime_env::RuntimeEnv;
+use crate::execution::MemoryConsumerId;
+use crate::physical_plan::metrics::{
+    BaselineMetrics, Count, ExecutionPlanMetricsSet, Time,
+};
+use std::sync::Arc;
+use std::task::Poll;
+
+use arrow::{error::ArrowError, record_batch::RecordBatch};
+
+/// Simplified version of tracking memory consumer,
+/// see also: [`Tracking`](crate::execution::memory_manager::ConsumerType::Tracking)
+///
+/// You could use this to replace [BaselineMetrics], report the memory,
+/// and get the memory usage bookkeeping in the memory manager easily.
+#[derive(Debug)]
+pub struct MemTrackingMetrics {
+    id: MemoryConsumerId,
+    runtime: Option<Arc<RuntimeEnv>>,
+    metrics: BaselineMetrics,
+}
+
+/// Delegates most of the metrics functionalities to the inner BaselineMetrics,
+/// intercept memory metrics functionalities and do memory manager bookkeeping.
+impl MemTrackingMetrics {
+    /// Create metrics similar to [BaselineMetrics]
+    pub fn new(metrics: &ExecutionPlanMetricsSet, partition: usize) -> Self {
+        let id = MemoryConsumerId::new(partition);
+        Self {
+            id,
+            runtime: None,
+            metrics: BaselineMetrics::new(metrics, partition),
+        }
+    }
+
+    /// Create memory tracking metrics with reference to runtime
+    pub fn new_with_rt(
+        metrics: &ExecutionPlanMetricsSet,
+        partition: usize,
+        runtime: Arc<RuntimeEnv>,
+    ) -> Self {
+        let id = MemoryConsumerId::new(partition);
+        Self {
+            id,
+            runtime: Some(runtime),
+            metrics: BaselineMetrics::new(metrics, partition),
+        }
+    }
+
+    /// return the metric for cpu time spend in this operator
+    pub fn elapsed_compute(&self) -> &Time {
+        self.metrics.elapsed_compute()
+    }
+
+    /// return the size for current memory usage
+    pub fn mem_used(&self) -> usize {
+        self.metrics.mem_used().value()
+    }
+
+    /// setup initial memory usage and register it with memory manager
+    pub fn init_mem_used(&self, size: usize) {
+        self.metrics.mem_used().set(size);
+        if let Some(rt) = self.runtime.as_ref() {
+            rt.memory_manager.grow_tracker_usage(size);
+        }
+    }
+
+    /// return the metric for the total number of output rows produced
+    pub fn output_rows(&self) -> &Count {
+        self.metrics.output_rows()
+    }
+
+    /// Records the fact that this operator's execution is complete
+    /// (recording the `end_time` metric).
+    ///
+    /// Note care should be taken to call `done()` manually if
+    /// `MemTrackingMetrics` is not `drop`ped immediately upon operator
+    /// completion, as async streams may not be dropped immediately
+    /// depending on the consumer.
+    pub fn done(&self) {
+        self.metrics.done()
+    }
+
+    /// Record that some number of rows have been produced as output
+    ///
+    /// See the [`RecordOutput`] for conveniently recording record
+    /// batch output for other thing
+    pub fn record_output(&self, num_rows: usize) {
+        self.metrics.record_output(num_rows)
+    }
+
+    /// Process a poll result of a stream producing output for an
+    /// operator, recording the output rows and stream done time and
+    /// returning the same poll result
+    pub fn record_poll(
+        &self,
+        poll: Poll<Option<Result<RecordBatch, ArrowError>>>,
+    ) -> Poll<Option<Result<RecordBatch, ArrowError>>> {
+        self.metrics.record_poll(poll)
+    }
+}
+
+impl Drop for MemTrackingMetrics {
+    fn drop(&mut self) {
+        self.metrics.try_done();
+        if self.mem_used() != 0 {
+            if let Some(rt) = self.runtime.as_ref() {
+                rt.drop_consumer(&self.id, self.mem_used());
+            }
+        }
+    }
+}
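
The `Drop` impl is what makes the bookkeeping "easy": whatever memory the metrics still report is handed back through `drop_consumer` when the owning stream goes away. A minimal sketch matching the new assertions in the memory_manager tests (the 15 bytes are illustrative):

    let ms = ExecutionPlanMetricsSet::new();
    let tracking = MemTrackingMetrics::new_with_rt(&ms, 0, runtime.clone());
    tracking.init_mem_used(15); // tracker total grows by 15
    // ... the stream executes ...
    drop(tracking);             // try_done() + drop_consumer(id, 15): total shrinks by 15
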
diff --git a/datafusion/src/physical_plan/sorts/mod.rs b/datafusion/src/physical_plan/sorts/mod.rs
index 7855568..64ec291 100644
--- a/datafusion/src/physical_plan/sorts/mod.rs
+++ b/datafusion/src/physical_plan/sorts/mod.rs
@@ -248,15 +248,6 @@ enum StreamWrapper {
     Stream(Option<SortedStream>),
 }
 
-impl StreamWrapper {
-    fn mem_used(&self) -> usize {
-        match &self {
-            StreamWrapper::Stream(Some(s)) => s.mem_used,
-            _ => 0,
-        }
-    }
-}
-
 impl Stream for StreamWrapper {
     type Item = ArrowResult<RecordBatch>;
 
diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs
index d40d6cf..7266b6c 100644
--- a/datafusion/src/physical_plan/sorts/sort.rs
+++ b/datafusion/src/physical_plan/sorts/sort.rs
@@ -26,7 +26,9 @@ use crate::execution::memory_manager::{
 use crate::execution::runtime_env::RuntimeEnv;
 use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream};
 use crate::physical_plan::expressions::PhysicalSortExpr;
-use crate::physical_plan::metrics::{AggregatedMetricsSet, BaselineMetrics, MetricsSet};
+use crate::physical_plan::metrics::{
+    BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet,
+};
 use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream;
 use crate::physical_plan::sorts::SortedStream;
 use crate::physical_plan::stream::RecordBatchReceiverStream;
@@ -73,8 +75,8 @@ struct ExternalSorter {
     /// Sort expressions
     expr: Vec<PhysicalSortExpr>,
     runtime: Arc<RuntimeEnv>,
-    metrics: AggregatedMetricsSet,
-    inner_metrics: BaselineMetrics,
+    metrics_set: CompositeMetricsSet,
+    metrics: BaselineMetrics,
 }
 
 impl ExternalSorter {
@@ -82,10 +84,10 @@ impl ExternalSorter {
         partition_id: usize,
         schema: SchemaRef,
         expr: Vec<PhysicalSortExpr>,
-        metrics: AggregatedMetricsSet,
+        metrics_set: CompositeMetricsSet,
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
-        let inner_metrics = metrics.new_intermediate_baseline(partition_id);
+        let metrics = metrics_set.new_intermediate_baseline(partition_id);
         Self {
             id: MemoryConsumerId::new(partition_id),
             schema,
@@ -93,8 +95,8 @@ impl ExternalSorter {
             spills: Mutex::new(vec![]),
             expr,
             runtime,
+            metrics_set,
             metrics,
-            inner_metrics,
         }
     }
 
@@ -102,7 +104,7 @@ impl ExternalSorter {
         if input.num_rows() > 0 {
             let size = batch_byte_size(&input);
             self.try_grow(size).await?;
-            self.inner_metrics.mem_used().add(size);
+            self.metrics.mem_used().add(size);
             let mut in_mem_batches = self.in_mem_batches.lock().await;
             in_mem_batches.push(input);
         }
@@ -120,16 +122,18 @@ impl ExternalSorter {
         let mut in_mem_batches = self.in_mem_batches.lock().await;
 
         if self.spilled_before().await {
-            let baseline_metrics = self.metrics.new_intermediate_baseline(partition);
+            let tracking_metrics = self
+                .metrics_set
+                .new_intermediate_tracking(partition, self.runtime.clone());
             let mut streams: Vec<SortedStream> = vec![];
             if in_mem_batches.len() > 0 {
                 let in_mem_stream = in_mem_partial_sort(
                     &mut *in_mem_batches,
                     self.schema.clone(),
                     &self.expr,
-                    baseline_metrics,
+                    tracking_metrics,
                 )?;
-                let prev_used = self.inner_metrics.mem_used().set(0);
+                let prev_used = self.metrics.mem_used().set(0);
                 streams.push(SortedStream::new(in_mem_stream, prev_used));
             }
 
@@ -139,25 +143,28 @@ impl ExternalSorter {
                 let stream = read_spill_as_stream(spill, self.schema.clone())?;
                 streams.push(SortedStream::new(stream, 0));
             }
-            let baseline_metrics = self.metrics.new_final_baseline(partition);
+            let tracking_metrics = self
+                .metrics_set
+                .new_final_tracking(partition, self.runtime.clone());
             Ok(Box::pin(SortPreservingMergeStream::new_from_streams(
                 streams,
                 self.schema.clone(),
                 &self.expr,
-                baseline_metrics,
-                partition,
+                tracking_metrics,
                 self.runtime.clone(),
             )))
         } else if in_mem_batches.len() > 0 {
-            let baseline_metrics = self.metrics.new_final_baseline(partition);
+            let tracking_metrics = self
+                .metrics_set
+                .new_final_tracking(partition, self.runtime.clone());
             let result = in_mem_partial_sort(
                 &mut *in_mem_batches,
                 self.schema.clone(),
                 &self.expr,
-                baseline_metrics,
+                tracking_metrics,
             );
-            self.inner_metrics.mem_used().set(0);
-            // TODO: the result size is not tracked
+            // Report to the memory manager we are no longer using memory
+            self.metrics.mem_used().set(0);
             result
         } else {
             Ok(Box::pin(EmptyRecordBatchStream::new(self.schema.clone())))
@@ -165,15 +172,15 @@ impl ExternalSorter {
     }
 
     fn used(&self) -> usize {
-        self.inner_metrics.mem_used().value()
+        self.metrics.mem_used().value()
     }
 
     fn spilled_bytes(&self) -> usize {
-        self.inner_metrics.spilled_bytes().value()
+        self.metrics.spilled_bytes().value()
     }
 
     fn spill_count(&self) -> usize {
-        self.inner_metrics.spill_count().value()
+        self.metrics.spill_count().value()
     }
 }
 
@@ -188,6 +195,12 @@ impl Debug for ExternalSorter {
     }
 }
 
+impl Drop for ExternalSorter {
+    fn drop(&mut self) {
+        self.runtime.drop_consumer(self.id(), self.used());
+    }
+}
+
 #[async_trait]
 impl MemoryConsumer for ExternalSorter {
     fn name(&self) -> String {
@@ -222,27 +235,29 @@ impl MemoryConsumer for ExternalSorter {
             return Ok(0);
         }
 
-        let baseline_metrics = self.metrics.new_intermediate_baseline(partition);
+        let tracking_metrics = self
+            .metrics_set
+            .new_intermediate_tracking(partition, self.runtime.clone());
 
         let spillfile = self.runtime.disk_manager.create_tmp_file()?;
         let stream = in_mem_partial_sort(
             &mut *in_mem_batches,
             self.schema.clone(),
             &*self.expr,
-            baseline_metrics,
+            tracking_metrics,
         );
 
         spill_partial_sorted_stream(&mut stream?, spillfile.path(), self.schema.clone())
             .await?;
         let mut spills = self.spills.lock().await;
-        let used = self.inner_metrics.mem_used().set(0);
-        self.inner_metrics.record_spill(used);
+        let used = self.metrics.mem_used().set(0);
+        self.metrics.record_spill(used);
         spills.push(spillfile);
         Ok(used)
     }
 
     fn mem_used(&self) -> usize {
-        self.inner_metrics.mem_used().value()
+        self.metrics.mem_used().value()
     }
 }
 
@@ -251,14 +266,14 @@ fn in_mem_partial_sort(
     buffered_batches: &mut Vec<RecordBatch>,
     schema: SchemaRef,
     expressions: &[PhysicalSortExpr],
-    baseline_metrics: BaselineMetrics,
+    tracking_metrics: MemTrackingMetrics,
 ) -> Result<SendableRecordBatchStream> {
     assert_ne!(buffered_batches.len(), 0);
 
     let result = {
         // NB timer records time taken on drop, so there are no
         // calls to `timer.done()` below.
-        let _timer = baseline_metrics.elapsed_compute().timer();
+        let _timer = tracking_metrics.elapsed_compute().timer();
 
         let pre_sort = if buffered_batches.len() == 1 {
             buffered_batches.pop()
@@ -276,7 +291,7 @@ fn in_mem_partial_sort(
     Ok(Box::pin(SizedRecordBatchStream::new(
         schema,
         vec![Arc::new(result.unwrap())],
-        baseline_metrics,
+        tracking_metrics,
     )))
 }
 
@@ -357,7 +372,7 @@ pub struct SortExec {
     /// Sort expressions
     expr: Vec<PhysicalSortExpr>,
     /// Containing all metrics set created during sort
-    all_metrics: AggregatedMetricsSet,
+    metrics_set: CompositeMetricsSet,
     /// Preserve partitions of input plan
     preserve_partitioning: bool,
 }
@@ -381,7 +396,7 @@ impl SortExec {
         Self {
             expr,
             input,
-            all_metrics: AggregatedMetricsSet::new(),
+            metrics_set: CompositeMetricsSet::new(),
             preserve_partitioning,
         }
     }
@@ -470,14 +485,14 @@ impl ExecutionPlan for SortExec {
             input,
             partition,
             self.expr.clone(),
-            self.all_metrics.clone(),
+            self.metrics_set.clone(),
             runtime,
         )
         .await
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
-        Some(self.all_metrics.aggregate_all())
+        Some(self.metrics_set.aggregate_all())
     }
 
     fn fmt_as(
@@ -537,27 +552,23 @@ async fn do_sort(
     mut input: SendableRecordBatchStream,
     partition_id: usize,
     expr: Vec<PhysicalSortExpr>,
-    metrics: AggregatedMetricsSet,
+    metrics_set: CompositeMetricsSet,
     runtime: Arc<RuntimeEnv>,
 ) -> Result<SendableRecordBatchStream> {
     let schema = input.schema();
-    let sorter = Arc::new(ExternalSorter::new(
+    let sorter = ExternalSorter::new(
         partition_id,
         schema.clone(),
         expr,
-        metrics,
+        metrics_set,
         runtime.clone(),
-    ));
-    runtime.register_consumer(&(sorter.clone() as Arc<dyn MemoryConsumer>));
-
+    );
+    runtime.register_requester(sorter.id());
     while let Some(batch) = input.next().await {
         let batch = batch?;
         sorter.insert_batch(batch).await?;
     }
-
-    let result = sorter.sort().await;
-    runtime.drop_consumer(sorter.id());
-    result
+    sorter.sort().await
 }
 
 #[cfg(test)]
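
With the `Drop` impl above, `do_sort` no longer wraps the sorter in an `Arc<dyn MemoryConsumer>` or calls `drop_consumer` explicitly; the lifecycle reduces to roughly the following (a simplified restatement of `do_sort`):

    let sorter = ExternalSorter::new(partition_id, schema.clone(), expr, metrics_set, runtime.clone());
    runtime.register_requester(sorter.id());    // join the requester set
    while let Some(batch) = input.next().await {
        sorter.insert_batch(batch?).await?;     // may trigger spilling via the manager
    }
    let stream = sorter.sort().await?;          // remaining usage moves to the output stream's metrics
    // when `sorter` is dropped, its Drop impl reports any leftover usage (normally 0) to the manager
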
diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
index 2ac468b..7b9d5d5 100644
--- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs
@@ -19,11 +19,11 @@
 
 use crate::physical_plan::common::AbortOnDropMany;
 use crate::physical_plan::metrics::{
-    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
+    ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet,
 };
 use std::any::Any;
 use std::collections::{BinaryHeap, VecDeque};
-use std::fmt::{Debug, Formatter};
+use std::fmt::Debug;
 use std::pin::Pin;
 use std::sync::{Arc, Mutex};
 use std::task::{Context, Poll};
@@ -41,9 +41,6 @@ use futures::stream::FusedStream;
 use futures::{Stream, StreamExt};
 
 use crate::error::{DataFusionError, Result};
-use crate::execution::memory_manager::{
-    ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager,
-};
 use crate::execution::runtime_env::RuntimeEnv;
 use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper};
 use crate::physical_plan::{
@@ -161,7 +158,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
             )));
         }
 
-        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
+        let tracking_metrics = MemTrackingMetrics::new(&self.metrics, partition);
 
         let input_partitions = self.input.output_partitioning().partition_count();
         match input_partitions {
@@ -193,8 +190,7 @@ impl ExecutionPlan for SortPreservingMergeExec {
                     AbortOnDropMany(join_handles),
                     self.schema(),
                     &self.expr,
-                    baseline_metrics,
-                    partition,
+                    tracking_metrics,
                     runtime,
                 )))
             }
@@ -223,36 +219,19 @@ impl ExecutionPlan for SortPreservingMergeExec {
     }
 }
 
+#[derive(Debug)]
 struct MergingStreams {
-    /// ConsumerId
-    id: MemoryConsumerId,
     /// The sorted input streams to merge together
     streams: Mutex<Vec<StreamWrapper>>,
     /// number of streams
     num_streams: usize,
-    /// Runtime
-    runtime: Arc<RuntimeEnv>,
-}
-
-impl Debug for MergingStreams {
-    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
-        f.debug_struct("MergingStreams")
-            .field("id", &self.id())
-            .finish()
-    }
 }
 
 impl MergingStreams {
-    fn new(
-        partition: usize,
-        input_streams: Vec<StreamWrapper>,
-        runtime: Arc<RuntimeEnv>,
-    ) -> Self {
+    fn new(input_streams: Vec<StreamWrapper>) -> Self {
         Self {
-            id: MemoryConsumerId::new(partition),
             num_streams: input_streams.len(),
             streams: Mutex::new(input_streams),
-            runtime,
         }
     }
 
@@ -261,45 +240,13 @@ impl MergingStreams {
     }
 }
 
-#[async_trait]
-impl MemoryConsumer for MergingStreams {
-    fn name(&self) -> String {
-        "MergingStreams".to_owned()
-    }
-
-    fn id(&self) -> &MemoryConsumerId {
-        &self.id
-    }
-
-    fn memory_manager(&self) -> Arc<MemoryManager> {
-        self.runtime.memory_manager.clone()
-    }
-
-    fn type_(&self) -> &ConsumerType {
-        &ConsumerType::Tracking
-    }
-
-    async fn spill(&self) -> Result<usize> {
-        return Err(DataFusionError::Internal(format!(
-            "Calling spill on a tracking only consumer {}, {}",
-            self.name(),
-            self.id,
-        )));
-    }
-
-    fn mem_used(&self) -> usize {
-        let streams = self.streams.lock().unwrap();
-        streams.iter().map(StreamWrapper::mem_used).sum::<usize>()
-    }
-}
-
 #[derive(Debug)]
 pub(crate) struct SortPreservingMergeStream {
     /// The schema of the RecordBatches yielded by this stream
     schema: SchemaRef,
 
     /// The sorted input streams to merge together
-    streams: Arc<MergingStreams>,
+    streams: MergingStreams,
 
     /// Drop helper for tasks feeding the [`receivers`](Self::receivers)
     _drop_helper: AbortOnDropMany<()>,
@@ -324,7 +271,7 @@ pub(crate) struct SortPreservingMergeStream {
     sort_options: Arc<Vec<SortOptions>>,
 
     /// used to record execution metrics
-    baseline_metrics: BaselineMetrics,
+    tracking_metrics: MemTrackingMetrics,
 
     /// If the stream has encountered an error
     aborted: bool,
@@ -335,25 +282,17 @@ pub(crate) struct SortPreservingMergeStream {
     /// min heap for record comparison
     min_heap: BinaryHeap<SortKeyCursor>,
 
-    /// runtime
-    runtime: Arc<RuntimeEnv>,
-}
-
-impl Drop for SortPreservingMergeStream {
-    fn drop(&mut self) {
-        self.runtime.drop_consumer(self.streams.id())
-    }
+    /// target batch size
+    batch_size: usize,
 }
 
 impl SortPreservingMergeStream {
-    #[allow(clippy::too_many_arguments)]
     pub(crate) fn new_from_receivers(
         receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
         _drop_helper: AbortOnDropMany<()>,
         schema: SchemaRef,
         expressions: &[PhysicalSortExpr],
-        baseline_metrics: BaselineMetrics,
-        partition: usize,
+        tracking_metrics: MemTrackingMetrics,
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
         let stream_count = receivers.len();
@@ -362,23 +301,21 @@ impl SortPreservingMergeStream {
             .map(|_| VecDeque::new())
             .collect();
         let wrappers = receivers.into_iter().map(StreamWrapper::Receiver).collect();
-        let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone()));
-        runtime.register_consumer(&(streams.clone() as Arc<dyn MemoryConsumer>));
 
         SortPreservingMergeStream {
             schema,
             batches,
             cursor_finished: vec![true; stream_count],
-            streams,
+            streams: MergingStreams::new(wrappers),
             _drop_helper,
             column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
             sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
-            baseline_metrics,
+            tracking_metrics,
             aborted: false,
             in_progress: vec![],
             next_batch_id: 0,
             min_heap: BinaryHeap::with_capacity(stream_count),
-            runtime,
+            batch_size: runtime.batch_size(),
         }
     }
 
@@ -386,8 +323,7 @@ impl SortPreservingMergeStream {
         streams: Vec<SortedStream>,
         schema: SchemaRef,
         expressions: &[PhysicalSortExpr],
-        baseline_metrics: BaselineMetrics,
-        partition: usize,
+        tracking_metrics: MemTrackingMetrics,
         runtime: Arc<RuntimeEnv>,
     ) -> Self {
         let stream_count = streams.len();
@@ -395,27 +331,26 @@ impl SortPreservingMergeStream {
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
+        tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum());
         let wrappers = streams
             .into_iter()
             .map(|s| StreamWrapper::Stream(Some(s)))
             .collect();
-        let streams = Arc::new(MergingStreams::new(partition, wrappers, runtime.clone()));
-        runtime.register_consumer(&(streams.clone() as Arc<dyn MemoryConsumer>));
 
         Self {
             schema,
             batches,
             cursor_finished: vec![true; stream_count],
-            streams,
+            streams: MergingStreams::new(wrappers),
             _drop_helper: AbortOnDropMany(vec![]),
             column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
             sort_options: Arc::new(expressions.iter().map(|x| x.options).collect()),
-            baseline_metrics,
+            tracking_metrics,
             aborted: false,
             in_progress: vec![],
             next_batch_id: 0,
             min_heap: BinaryHeap::with_capacity(stream_count),
-            runtime,
+            batch_size: runtime.batch_size(),
         }
     }
 
@@ -577,7 +512,7 @@ impl Stream for SortPreservingMergeStream {
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
         let poll = self.poll_next_inner(cx);
-        self.baseline_metrics.record_poll(poll)
+        self.tracking_metrics.record_poll(poll)
     }
 }
 
@@ -606,7 +541,7 @@ impl SortPreservingMergeStream {
         loop {
             // NB timer records time taken on drop, so there are no
             // calls to `timer.done()` below.
-            let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
+            let elapsed_compute = self.tracking_metrics.elapsed_compute().clone();
             let _timer = elapsed_compute.timer();
 
             match self.min_heap.pop() {
@@ -630,7 +565,7 @@ impl SortPreservingMergeStream {
                         row_idx,
                     });
 
-                    if self.in_progress.len() == self.runtime.batch_size() {
+                    if self.in_progress.len() == self.batch_size {
                         return Poll::Ready(Some(self.build_record_batch()));
                     }
 
@@ -1263,7 +1198,7 @@ mod tests {
         }
 
         let metrics = ExecutionPlanMetricsSet::new();
-        let baseline_metrics = BaselineMetrics::new(&metrics, 0);
+        let tracking_metrics = MemTrackingMetrics::new(&metrics, 0);
 
         let merge_stream = SortPreservingMergeStream::new_from_receivers(
             receivers,
@@ -1271,8 +1206,7 @@ mod tests {
             AbortOnDropMany(vec![]),
             batches.schema(),
             sort.as_slice(),
-            baseline_metrics,
-            0,
+            tracking_metrics,
             runtime.clone(),
         );
 
diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs
index 5a4f907..3aac5a8 100644
--- a/datafusion/tests/provider_filter_pushdown.rs
+++ b/datafusion/tests/provider_filter_pushdown.rs
@@ -25,7 +25,7 @@ use datafusion::execution::context::ExecutionContext;
 use datafusion::execution::runtime_env::RuntimeEnv;
 use datafusion::logical_plan::Expr;
 use datafusion::physical_plan::common::SizedRecordBatchStream;
-use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};
+use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics};
 use datafusion::physical_plan::{
     DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics,
 };
@@ -86,11 +86,11 @@ impl ExecutionPlan for CustomPlan {
         _runtime: Arc<RuntimeEnv>,
     ) -> Result<SendableRecordBatchStream> {
         let metrics = ExecutionPlanMetricsSet::new();
-        let baseline_metrics = BaselineMetrics::new(&metrics, partition);
+        let tracking_metrics = MemTrackingMetrics::new(&metrics, partition);
         Ok(Box::pin(SizedRecordBatchStream::new(
             self.schema(),
             self.batches.clone(),
-            baseline_metrics,
+            tracking_metrics,
         )))
     }
 
