This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new a6dcd94305 Extract GroupValues (#6969) (#7016)
a6dcd94305 is described below

commit a6dcd943051a083693c352c6b4279156548490a0
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed Jul 19 17:56:41 2023 -0400

    Extract GroupValues (#6969) (#7016)
---
 .../core/src/physical_plan/aggregates/order/mod.rs |   2 +-
 .../core/src/physical_plan/aggregates/row_hash.rs  | 506 ++++++++++-----------
 datafusion/execution/src/memory_pool/mod.rs        |  43 --
 3 files changed, 245 insertions(+), 306 deletions(-)

diff --git a/datafusion/core/src/physical_plan/aggregates/order/mod.rs 
b/datafusion/core/src/physical_plan/aggregates/order/mod.rs
index 81bf38aac3..ebe662c980 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/mod.rs
@@ -83,7 +83,7 @@ impl GroupOrdering {
     }
 
     /// remove the first n groups from the internal state, shifting
-    /// all existing indexes down by `n`. Returns stored hash values
+    /// all existing indexes down by `n`
     pub fn remove_groups(&mut self, n: usize) {
         match self {
             GroupOrdering::None => {}
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs 
b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 59ffbe5cf1..e3ac5c49a9 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -42,7 +42,7 @@ use arrow::array::*;
 use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
 use datafusion_common::Result;
 use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
-use datafusion_execution::memory_pool::{MemoryConsumer, MemoryDelta, 
MemoryReservation};
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
 use datafusion_execution::TaskContext;
 use hashbrown::raw::RawTable;
 
@@ -59,6 +59,181 @@ pub(crate) enum ExecutionState {
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
+/// An interning store for group keys
+trait GroupValues: Send {
+    /// Calculates the `groups` for each input row of `cols`
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> 
Result<()>;
+
+    /// Returns the number of bytes used by this [`GroupValues`]
+    fn size(&self) -> usize;
+
+    /// Returns true if this [`GroupValues`] is empty
+    fn is_empty(&self) -> bool;
+
+    /// The number of values stored in this [`GroupValues`]
+    fn len(&self) -> usize;
+
+    /// Emits the group values
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+}
+
+/// A [`GroupValues`] making use of [`Rows`]
+struct GroupValuesRows {
+    /// Converter for the group values
+    row_converter: RowConverter,
+
+    /// Logically maps group values to a group_index in
+    /// [`Self::group_values`] and in each accumulator
+    ///
+    /// Uses the raw API of hashbrown to avoid actually storing the
+    /// keys (group values) in the table
+    ///
+    /// keys: u64 hashes of the GroupValue
+    /// values: (hash, group_index)
+    map: RawTable<(u64, usize)>,
+
+    /// The size of `map` in bytes
+    map_size: usize,
+
+    /// The actual group by values, stored in arrow [`Row`] format.
+    /// `group_values[i]` holds the group value for group_index `i`.
+    ///
+    /// The row format is used to compare group keys quickly and store
+    /// them efficiently in memory. Quick comparison is especially
+    /// important for multi-column group keys.
+    ///
+    /// [`Row`]: arrow::row::Row
+    group_values: Rows,
+
+    // buffer to be reused to store hashes
+    hashes_buffer: Vec<u64>,
+
+    /// Random state for creating hashes
+    random_state: RandomState,
+}
+
+impl GroupValuesRows {
+    fn try_new(schema: SchemaRef) -> Result<Self> {
+        let row_converter = RowConverter::new(
+            schema
+                .fields()
+                .iter()
+                .map(|f| SortField::new(f.data_type().clone()))
+                .collect(),
+        )?;
+
+        let map = RawTable::with_capacity(0);
+        let group_values = row_converter.empty_rows(0, 0);
+
+        Ok(Self {
+            row_converter,
+            map,
+            map_size: 0,
+            group_values,
+            hashes_buffer: Default::default(),
+            random_state: Default::default(),
+        })
+    }
+}
+
+impl GroupValues for GroupValuesRows {
+    fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> 
Result<()> {
+        // Convert the group keys into the row format
+        // Avoid reallocation when 
https://github.com/apache/arrow-rs/issues/4479 is available
+        let group_rows = self.row_converter.convert_columns(cols)?;
+        let n_rows = group_rows.num_rows();
+
+        // tracks to which group each of the input rows belongs
+        groups.clear();
+
+        // 1.1 Calculate the group keys for the group values
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(n_rows, 0);
+        create_hashes(cols, &self.random_state, batch_hashes)?;
+
+        for (row, &hash) in batch_hashes.iter().enumerate() {
+            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
+                // verify that a group that we are inserting with hash is
+                // actually the same key value as the group in
+                // existing_idx  (aka group_values @ row)
+                group_rows.row(row) == self.group_values.row(*group_idx)
+            });
+
+            let group_idx = match entry {
+                // Existing group_index for this group value
+                Some((_hash, group_idx)) => *group_idx,
+                //  1.2 Need to create new entry for the group
+                None => {
+                    // Add new entry to aggr_state and save newly created index
+                    let group_idx = self.group_values.num_rows();
+                    self.group_values.push(group_rows.row(row));
+
+                    // for hasher function, use precomputed hash value
+                    self.map.insert_accounted(
+                        (hash, group_idx),
+                        |(hash, _group_index)| *hash,
+                        &mut self.map_size,
+                    );
+                    group_idx
+                }
+            };
+            groups.push(group_idx);
+        }
+
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        self.row_converter.size()
+            + self.group_values.size()
+            + self.map_size
+            + self.hashes_buffer.allocated_size()
+    }
+
+    fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    fn len(&self) -> usize {
+        self.group_values.num_rows()
+    }
+
+    fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        Ok(match emit_to {
+            EmitTo::All => {
+                // Eventually we may also want to clear the hash table here
+                self.row_converter.convert_rows(&self.group_values)?
+            }
+            EmitTo::First(n) => {
+                let groups_rows = self.group_values.iter().take(n);
+                let output = self.row_converter.convert_rows(groups_rows)?;
+                // Clear out first n group keys by copying them to a new Rows.
+                // TODO file some ticket in arrow-rs to make this more 
efficient?
+                let mut new_group_values = self.row_converter.empty_rows(0, 0);
+                for row in self.group_values.iter().skip(n) {
+                    new_group_values.push(row);
+                }
+                std::mem::swap(&mut new_group_values, &mut self.group_values);
+
+                // SAFETY: self.map outlives iterator and is not modified 
concurrently
+                unsafe {
+                    for bucket in self.map.iter() {
+                        // Decrement group index by n
+                        match bucket.as_ref().1.checked_sub(n) {
+                            // Group index was >= n, shift value down
+                            Some(sub) => bucket.as_mut().1 = sub,
+                            // Group index was < n, so remove from table
+                            None => self.map.erase(bucket),
+                        }
+                    }
+                }
+                output
+            }
+        })
+    }
+}
+
 /// Hash based Grouping Aggregator
 ///
 /// # Design Goals
@@ -74,29 +249,29 @@ use super::AggregateExec;
 ///
 /// ```text
 ///
-/// stores "group       stores group values,       internally stores aggregate
-///    indexes"          in arrow_row format         values, for all groups
+///     Assigns a consecutive group           internally stores aggregate 
values
+///     index for each unique set                     for all groups
+///         of group values
 ///
-/// ┌─────────────┐      ┌────────────┐    ┌──────────────┐       
┌──────────────┐
-/// │   ┌─────┐   │      │ ┌────────┐ │    │┌────────────┐│       
│┌────────────┐│
-/// │   │  5  │   │ ┌────┼▶│  "A"   │ │    ││accumulator ││       
││accumulator ││
-/// │   ├─────┤   │ │    │ ├────────┤ │    ││     0      ││       ││     N     
 ││
-/// │   │  9  │   │ │    │ │  "Z"   │ │    ││ ┌────────┐ ││       ││ 
┌────────┐ ││
-/// │   └─────┘   │ │    │ └────────┘ │    ││ │ state  │ ││       ││ │ state  
│ ││
-/// │     ...     │ │    │            │    ││ │┌─────┐ │ ││  ...  ││ │┌─────┐ 
│ ││
-/// │   ┌─────┐   │ │    │    ...     │    ││ │├─────┤ │ ││       ││ │├─────┤ 
│ ││
-/// │   │  1  │───┼─┘    │            │    ││ │└─────┘ │ ││       ││ │└─────┘ 
│ ││
-/// │   ├─────┤   │      │            │    ││ │        │ ││       ││ │        
│ ││
-/// │   │ 13  │───┼─┐    │ ┌────────┐ │    ││ │  ...   │ ││       ││ │  ...   
│ ││
-/// │   └─────┘   │ └────┼▶│  "Q"   │ │    ││ │        │ ││       ││ │        
│ ││
-/// └─────────────┘      │ └────────┘ │    ││ │┌─────┐ │ ││       ││ │┌─────┐ 
│ ││
-///                      │            │    ││ │└─────┘ │ ││       ││ │└─────┘ 
│ ││
-///                      └────────────┘    ││ └────────┘ ││       ││ 
└────────┘ ││
-///                                        │└────────────┘│       
│└────────────┘│
-///                                        └──────────────┘       
└──────────────┘
+///         ┌────────────┐              ┌──────────────┐       ┌──────────────┐
+///         │ ┌────────┐ │              │┌────────────┐│       │┌────────────┐│
+///         │ │  "A"   │ │              ││accumulator ││       ││accumulator ││
+///         │ ├────────┤ │              ││     0      ││       ││     N      ││
+///         │ │  "Z"   │ │              ││ ┌────────┐ ││       ││ ┌────────┐ ││
+///         │ └────────┘ │              ││ │ state  │ ││       ││ │ state  │ ││
+///         │            │              ││ │┌─────┐ │ ││  ...  ││ │┌─────┐ │ ││
+///         │    ...     │              ││ │├─────┤ │ ││       ││ │├─────┤ │ ││
+///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
+///         │            │              ││ │        │ ││       ││ │        │ ││
+///         │ ┌────────┐ │              ││ │  ...   │ ││       ││ │  ...   │ ││
+///         │ │  "Q"   │ │              ││ │        │ ││       ││ │        │ ││
+///         │ └────────┘ │              ││ │┌─────┐ │ ││       ││ │┌─────┐ │ ││
+///         │            │              ││ │└─────┘ │ ││       ││ │└─────┘ │ ││
+///         └────────────┘              ││ └────────┘ ││       ││ └────────┘ ││
+///                                     │└────────────┘│       │└────────────┘│
+///                                     └──────────────┘       └──────────────┘
 ///
-///       map            group_values                   accumulators
-///  (Hash Table)
+///         group_values                             accumulators
 ///
 ///  ```
 ///
@@ -108,10 +283,10 @@ use super::AggregateExec;
 ///
 /// # Description
 ///
-/// The hash table does not store any aggregate state inline. It only
-/// stores "group indices", one for each (distinct) group value. The
+/// [`group_values`] does not store any aggregate state inline. It only
+/// assigns "group indices", one for each (distinct) group value. The
 /// accumulators manage the in-progress aggregate state for each
-/// group, and the group values themselves are stored in
+/// group, with the group values themselves stored in
 /// [`group_values`] at the corresponding group index.
 ///
 /// The accumulator state (e.g partial sums) is managed by and stored
@@ -152,40 +327,18 @@ pub(crate) struct GroupedHashAggregateStream {
     /// the filter expression is  `x > 100`.
     filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
 
-    /// Converter for the group values
-    row_converter: RowConverter,
-
     /// GROUP BY expressions
     group_by: PhysicalGroupBy,
 
     /// The memory reservation for this grouping
     reservation: MemoryReservation,
 
-    /// Logically maps group values to a group_index in
-    /// [`Self::group_values`] and in each accumulator
-    ///
-    /// Uses the raw API of hashbrown to avoid actually storing the
-    /// keys (group values) in the table
-    ///
-    /// keys: u64 hashes of the GroupValue
-    /// values: (hash, group_index)
-    map: RawTable<(u64, usize)>,
-
-    /// The actual group by values, stored in arrow [`Row`] format.
-    /// `group_values[i]` holds the group value for group_index `i`.
-    ///
-    /// The row format is used to compare group keys quickly and store
-    /// them efficiently in memory. Quick comparison is especially
-    /// important for multi-column group keys.
-    ///
-    /// [`Row`]: arrow::row::Row
-    group_values: Rows,
+    /// An interning store of group keys
+    group_values: Box<dyn GroupValues>,
 
     /// scratch space for the current input [`RecordBatch`] being
-    /// processed. The reason this is a field is so it can be reused
-    /// for all input batches, avoiding the need to reallocate Vecs on
-    /// each input.
-    scratch_space: ScratchSpace,
+    /// processed. Reused across batches here to avoid reallocations
+    current_group_indices: Vec<usize>,
 
     /// Tracks if this stream is generating input or output
     exec_state: ExecutionState,
@@ -193,9 +346,6 @@ pub(crate) struct GroupedHashAggregateStream {
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
 
-    /// Random state for creating hashes
-    random_state: RandomState,
-
     /// max rows in output RecordBatches
     batch_size: usize,
 
@@ -252,18 +402,9 @@ impl GroupedHashAggregateStream {
             .collect::<Result<_>>()?;
 
         let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
-        let row_converter = RowConverter::new(
-            group_schema
-                .fields()
-                .iter()
-                .map(|f| SortField::new(f.data_type().clone()))
-                .collect(),
-        )?;
 
         let name = format!("GroupedHashAggregateStream[{partition}]");
         let reservation = 
MemoryConsumer::new(name).register(context.memory_pool());
-        let map = RawTable::with_capacity(0);
-        let group_values = row_converter.empty_rows(0, 0);
 
         let group_ordering = agg
             .aggregation_ordering
@@ -275,6 +416,8 @@ impl GroupedHashAggregateStream {
             .transpose()?
             .unwrap_or(GroupOrdering::None);
 
+        let group = Box::new(GroupValuesRows::try_new(group_schema)?);
+
         timer.done();
 
         let exec_state = ExecutionState::ReadingInput;
@@ -286,15 +429,12 @@ impl GroupedHashAggregateStream {
             accumulators,
             aggregate_arguments,
             filter_expressions,
-            row_converter,
             group_by: agg_group_by,
             reservation,
-            map,
-            group_values,
-            scratch_space: ScratchSpace::new(),
+            group_values: group,
+            current_group_indices: Default::default(),
             exec_state,
             baseline_metrics,
-            random_state: Default::default(),
             batch_size,
             group_ordering,
             input_done: false,
@@ -355,11 +495,9 @@ impl Stream for GroupedHashAggregateStream {
                             // If we can begin emitting rows, do so,
                             // otherwise keep consuming input
                             assert!(!self.input_done);
-                            let to_emit = self.group_ordering.emit_to();
 
-                            if let Some(to_emit) = to_emit {
-                                let batch =
-                                    
extract_ok!(self.create_batch_from_map(to_emit));
+                            if let Some(to_emit) = 
self.group_ordering.emit_to() {
+                                let batch = extract_ok!(self.emit(to_emit));
                                 self.exec_state = 
ExecutionState::ProducingOutput(batch);
                             }
                             timer.done();
@@ -373,8 +511,7 @@ impl Stream for GroupedHashAggregateStream {
                             self.input_done = true;
                             self.group_ordering.input_done();
                             let timer = elapsed_compute.timer();
-                            let batch =
-                                
extract_ok!(self.create_batch_from_map(EmitTo::All));
+                            let batch = extract_ok!(self.emit(EmitTo::All));
                             self.exec_state = 
ExecutionState::ProducingOutput(batch);
                             timer.done();
                         }
@@ -415,95 +552,11 @@ impl RecordBatchStream for GroupedHashAggregateStream {
 }
 
 impl GroupedHashAggregateStream {
-    /// Calculates the group indices for each input row of
-    /// `group_values`.
-    ///
-    /// At the return of this function,
-    /// `self.scratch_space.current_group_indices` has the same number
-    /// of entries as each array in `group_values` and holds the
-    /// correct group_index for that row.
-    ///
-    /// This is one of the core hot loops in the algorithm
-    fn update_group_state(
-        &mut self,
-        group_values: &[ArrayRef],
-        memory_delta: &mut MemoryDelta,
-    ) -> Result<()> {
-        // Convert the group keys into the row format
-        // Avoid reallocation when 
https://github.com/apache/arrow-rs/issues/4479 is available
-        let group_rows = self.row_converter.convert_columns(group_values)?;
-        let n_rows = group_rows.num_rows();
-
-        // track memory used
-        memory_delta.dec(self.state_size());
-
-        // tracks to which group each of the input rows belongs
-        let group_indices = &mut self.scratch_space.current_group_indices;
-        group_indices.clear();
-
-        // 1.1 Calculate the group keys for the group values
-        let batch_hashes = &mut self.scratch_space.hashes_buffer;
-        batch_hashes.clear();
-        batch_hashes.resize(n_rows, 0);
-        create_hashes(group_values, &self.random_state, batch_hashes)?;
-
-        let mut allocated = 0;
-        let starting_num_groups = self.group_values.num_rows();
-        for (row, &hash) in batch_hashes.iter().enumerate() {
-            let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
-                // verify that a group that we are inserting with hash is
-                // actually the same key value as the group in
-                // existing_idx  (aka group_values @ row)
-                group_rows.row(row) == self.group_values.row(*group_idx)
-            });
-
-            let group_idx = match entry {
-                // Existing group_index for this group value
-                Some((_hash, group_idx)) => *group_idx,
-                //  1.2 Need to create new entry for the group
-                None => {
-                    // Add new entry to aggr_state and save newly created index
-                    let group_idx = self.group_values.num_rows();
-                    self.group_values.push(group_rows.row(row));
-
-                    // for hasher function, use precomputed hash value
-                    self.map.insert_accounted(
-                        (hash, group_idx),
-                        |(hash, _group_index)| *hash,
-                        &mut allocated,
-                    );
-                    group_idx
-                }
-            };
-            group_indices.push(group_idx);
-        }
-        memory_delta.inc(allocated);
-
-        // Update ordering information if necessary
-        let total_num_groups = self.group_values.num_rows();
-        if total_num_groups > starting_num_groups {
-            self.group_ordering.new_groups(
-                group_values,
-                group_indices,
-                total_num_groups,
-            )?;
-        }
-
-        // account for memory change
-        memory_delta.inc(self.state_size());
-
-        Ok(())
-    }
-
     /// Perform group-by aggregation for the given [`RecordBatch`].
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
         // Evaluate the grouping expressions
         let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
 
-        // Keep track of memory allocated:
-        let mut memory_delta = MemoryDelta::new();
-        memory_delta.dec(self.state_size());
-
         // Evaluate the aggregation expressions.
         let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
 
@@ -512,8 +565,20 @@ impl GroupedHashAggregateStream {
 
         for group_values in &group_by_values {
             // calculate the group indices for each input row
-            self.update_group_state(group_values, &mut memory_delta)?;
-            let group_indices = &self.scratch_space.current_group_indices;
+            let starting_num_groups = self.group_values.len();
+            self.group_values
+                .intern(group_values, &mut self.current_group_indices)?;
+            let group_indices = &self.current_group_indices;
+
+            // Update ordering information if necessary
+            let total_num_groups = self.group_values.len();
+            if total_num_groups > starting_num_groups {
+                self.group_ordering.new_groups(
+                    group_values,
+                    group_indices,
+                    total_num_groups,
+                )?;
+            }
 
             // Gather the inputs to call the actual accumulator
             let t = self
@@ -522,10 +587,7 @@ impl GroupedHashAggregateStream {
                 .zip(input_values.iter())
                 .zip(filter_values.iter());
 
-            let total_num_groups = self.group_values.num_rows();
-
             for ((acc, values), opt_filter) in t {
-                memory_delta.dec(acc.size());
                 let opt_filter = opt_filter.as_ref().map(|filter| 
filter.as_boolean());
 
                 // Call the appropriate method on each aggregator with
@@ -552,42 +614,32 @@ impl GroupedHashAggregateStream {
                         )?;
                     }
                 }
-                memory_delta.inc(acc.size());
             }
         }
-        memory_delta.inc(self.state_size());
 
-        // Update allocation AFTER it is used, simplifying accounting,
-        // though it results in a temporary overshoot.
-        memory_delta.update(&mut self.reservation)
+        self.update_memory_reservation()
+    }
+
+    fn update_memory_reservation(&mut self) -> Result<()> {
+        let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
+        self.reservation.try_resize(
+            acc + self.group_values.size()
+                + self.group_ordering.size()
+                + self.current_group_indices.allocated_size(),
+        )
     }
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn create_batch_from_map(&mut self, emit_to: EmitTo) -> 
Result<RecordBatch> {
-        if self.group_values.num_rows() == 0 {
+    fn emit(&mut self, emit_to: EmitTo) -> Result<RecordBatch> {
+        if self.group_values.is_empty() {
             return Ok(RecordBatch::new_empty(self.schema()));
         }
 
-        let output = self.build_output(emit_to)?;
-        self.remove_emitted(emit_to)?;
-        let batch = RecordBatch::try_new(self.schema(), output)?;
-        Ok(batch)
-    }
-
-    /// Creates output: `(group 1, group 2, ... agg 1, agg 2, ...)`
-    fn build_output(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
-        // First output rows are the groups
-        let mut output: Vec<ArrayRef> = match emit_to {
-            EmitTo::All => {
-                let groups_rows = self.group_values.iter();
-                self.row_converter.convert_rows(groups_rows)?
-            }
-            EmitTo::First(n) => {
-                let groups_rows = self.group_values.iter().take(n);
-                self.row_converter.convert_rows(groups_rows)?
-            }
-        };
+        let mut output = self.group_values.emit(emit_to)?;
+        if let EmitTo::First(n) = emit_to {
+            self.group_ordering.remove_groups(n);
+        }
 
         // Next output each aggregate value
         for acc in self.accumulators.iter_mut() {
@@ -600,78 +652,8 @@ impl GroupedHashAggregateStream {
             }
         }
 
-        Ok(output)
-    }
-
-    /// Removes the first `n` groups, adjusting all group_indices
-    /// appropriately
-    fn remove_emitted(&mut self, emit_to: EmitTo) -> Result<()> {
-        let mut memory_delta = MemoryDelta::new();
-        memory_delta.dec(self.state_size());
-
-        match emit_to {
-            EmitTo::All => {
-                // Eventually we may also want to clear the hash table here
-                //self.map.clear();
-            }
-            EmitTo::First(n) => {
-                // Clear out first n group keys by copying them to a new Rows.
-                // TODO file some ticket in arrow-rs to make this more 
efficent?
-                let mut new_group_values = self.row_converter.empty_rows(0, 0);
-                for row in self.group_values.iter().skip(n) {
-                    new_group_values.push(row);
-                }
-                std::mem::swap(&mut new_group_values, &mut self.group_values);
-
-                self.group_ordering.remove_groups(n);
-                // SAFETY: self.map outlives iterator and is not modified 
concurrently
-                unsafe {
-                    for bucket in self.map.iter() {
-                        // Decrement group index by n
-                        match bucket.as_ref().1.checked_sub(n) {
-                            // Group index was >= n, shift value down
-                            Some(sub) => bucket.as_mut().1 = sub,
-                            // Group index was < n, so remove from table
-                            None => self.map.erase(bucket),
-                        }
-                    }
-                }
-            }
-        };
-        // account for memory change
-        memory_delta.inc(self.state_size());
-        memory_delta.update(&mut self.reservation)
-    }
-
-    /// return the current size stored by variable state in this structure
-    fn state_size(&self) -> usize {
-        self.group_values.size()
-            + self.scratch_space.size()
-            + self.group_ordering.size()
-            + self.row_converter.size()
-    }
-}
-
-/// Holds structures used for the current input [`RecordBatch`] being
-/// processed. Reused across batches here to avoid reallocations
-#[derive(Debug, Default)]
-struct ScratchSpace {
-    /// scratch space for the current input [`RecordBatch`] being
-    /// processed. Reused across batches here to avoid reallocations
-    current_group_indices: Vec<usize>,
-    // buffer to be reused to store hashes
-    hashes_buffer: Vec<u64>,
-}
-
-impl ScratchSpace {
-    fn new() -> Self {
-        Default::default()
-    }
-
-    /// Return the amount of memory alocated by this structure in bytes
-    fn size(&self) -> usize {
-        std::mem::size_of_val(self)
-            + self.current_group_indices.allocated_size()
-            + self.hashes_buffer.allocated_size()
+        self.update_memory_reservation()?;
+        let batch = RecordBatch::try_new(self.schema(), output)?;
+        Ok(batch)
     }
 }
diff --git a/datafusion/execution/src/memory_pool/mod.rs 
b/datafusion/execution/src/memory_pool/mod.rs
index fe077524a4..011cd72cbb 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -219,49 +219,6 @@ pub fn human_readable_size(size: usize) -> String {
     format!("{value:.1} {unit}")
 }
 
-/// Tracks the change in memory to avoid overflow. Typically, this
-/// is isued like the following
-///
-/// 1. Call `delta.dec(sized_thing.size())`
-///
-/// 2. potentially change size of `sized_thing`
-///
-/// 3. Call `delta.inc(size_thing.size())`
-#[derive(Debug, Default)]
-pub struct MemoryDelta {
-    decrease: usize,
-    increase: usize,
-}
-
-impl MemoryDelta {
-    pub fn new() -> Self {
-        Default::default()
-    }
-
-    /// record size being 'decremented'. This is used for to record the
-    /// initial size of some allocation prior to hange
-    pub fn dec(&mut self, sz: usize) {
-        self.decrease += sz;
-    }
-
-    /// record size being 'incremented'. This is used for to record
-    /// the final size of some object.
-    pub fn inc(&mut self, sz: usize) {
-        self.increase += sz;
-    }
-
-    /// Adjusts the reservation with the delta used / freed
-    pub fn update(self, reservation: &mut MemoryReservation) -> Result<()> {
-        let Self { decrease, increase } = self;
-        match increase.cmp(&decrease) {
-            Ordering::Less => reservation.shrink(decrease - increase),
-            Ordering::Equal => {}
-            Ordering::Greater => reservation.try_grow(increase - decrease)?,
-        };
-        Ok(())
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;

Reply via email to