This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new a6dcd94305 Extract GroupValues (#6969) (#7016)
a6dcd94305 is described below
commit a6dcd943051a083693c352c6b4279156548490a0
Author: Raphael Taylor-Davies <[email protected]>
AuthorDate: Wed Jul 19 17:56:41 2023 -0400
Extract GroupValues (#6969) (#7016)
---
.../core/src/physical_plan/aggregates/order/mod.rs | 2 +-
.../core/src/physical_plan/aggregates/row_hash.rs | 506 ++++++++++-----------
datafusion/execution/src/memory_pool/mod.rs | 43 --
3 files changed, 245 insertions(+), 306 deletions(-)
diff --git a/datafusion/core/src/physical_plan/aggregates/order/mod.rs
b/datafusion/core/src/physical_plan/aggregates/order/mod.rs
index 81bf38aac3..ebe662c980 100644
--- a/datafusion/core/src/physical_plan/aggregates/order/mod.rs
+++ b/datafusion/core/src/physical_plan/aggregates/order/mod.rs
@@ -83,7 +83,7 @@ impl GroupOrdering {
}
/// remove the first n groups from the internal state, shifting
- /// all existing indexes down by `n`. Returns stored hash values
+ /// all existing indexes down by `n`
pub fn remove_groups(&mut self, n: usize) {
match self {
GroupOrdering::None => {}
diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
index 59ffbe5cf1..e3ac5c49a9 100644
--- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs
+++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs
@@ -42,7 +42,7 @@ use arrow::array::*;
use arrow::{datatypes::SchemaRef, record_batch::RecordBatch};
use datafusion_common::Result;
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
-use datafusion_execution::memory_pool::{MemoryConsumer, MemoryDelta,
MemoryReservation};
+use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::TaskContext;
use hashbrown::raw::RawTable;
@@ -59,6 +59,181 @@ pub(crate) enum ExecutionState {
use super::order::GroupOrdering;
use super::AggregateExec;
+/// An interning store for group keys
+trait GroupValues: Send {
+ /// Calculates the `groups` for each input row of `cols`
+ fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) ->
Result<()>;
+
+ /// Returns the number of bytes used by this [`GroupValues`]
+ fn size(&self) -> usize;
+
+ /// Returns true if this [`GroupValues`] is empty
+ fn is_empty(&self) -> bool;
+
+ /// The number of values stored in this [`GroupValues`]
+ fn len(&self) -> usize;
+
+ /// Emits the group values
+ fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>>;
+}
+
+/// A [`GroupValues`] making use of [`Rows`]
+struct GroupValuesRows {
+ /// Converter for the group values
+ row_converter: RowConverter,
+
+ /// Logically maps group values to a group_index in
+ /// [`Self::group_values`] and in each accumulator
+ ///
+ /// Uses the raw API of hashbrown to avoid actually storing the
+ /// keys (group values) in the table
+ ///
+ /// keys: u64 hashes of the GroupValue
+ /// values: (hash, group_index)
+ map: RawTable<(u64, usize)>,
+
+ /// The size of `map` in bytes
+ map_size: usize,
+
+ /// The actual group by values, stored in arrow [`Row`] format.
+ /// `group_values[i]` holds the group value for group_index `i`.
+ ///
+ /// The row format is used to compare group keys quickly and store
+ /// them efficiently in memory. Quick comparison is especially
+ /// important for multi-column group keys.
+ ///
+ /// [`Row`]: arrow::row::Row
+ group_values: Rows,
+
+ // buffer to be reused to store hashes
+ hashes_buffer: Vec<u64>,
+
+ /// Random state for creating hashes
+ random_state: RandomState,
+}
+
+impl GroupValuesRows {
+ fn try_new(schema: SchemaRef) -> Result<Self> {
+ let row_converter = RowConverter::new(
+ schema
+ .fields()
+ .iter()
+ .map(|f| SortField::new(f.data_type().clone()))
+ .collect(),
+ )?;
+
+ let map = RawTable::with_capacity(0);
+ let group_values = row_converter.empty_rows(0, 0);
+
+ Ok(Self {
+ row_converter,
+ map,
+ map_size: 0,
+ group_values,
+ hashes_buffer: Default::default(),
+ random_state: Default::default(),
+ })
+ }
+}
+
+impl GroupValues for GroupValuesRows {
+ fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) ->
Result<()> {
+ // Convert the group keys into the row format
+ // Avoid reallocation when
https://github.com/apache/arrow-rs/issues/4479 is available
+ let group_rows = self.row_converter.convert_columns(cols)?;
+ let n_rows = group_rows.num_rows();
+
+ // tracks to which group each of the input rows belongs
+ groups.clear();
+
+ // 1.1 Calculate the group keys for the group values
+ let batch_hashes = &mut self.hashes_buffer;
+ batch_hashes.clear();
+ batch_hashes.resize(n_rows, 0);
+ create_hashes(cols, &self.random_state, batch_hashes)?;
+
+ for (row, &hash) in batch_hashes.iter().enumerate() {
+ let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
+ // verify that a group that we are inserting with hash is
+ // actually the same key value as the group in
+ // existing_idx (aka group_values @ row)
+ group_rows.row(row) == self.group_values.row(*group_idx)
+ });
+
+ let group_idx = match entry {
+ // Existing group_index for this group value
+ Some((_hash, group_idx)) => *group_idx,
+ // 1.2 Need to create new entry for the group
+ None => {
+ // Add new entry to aggr_state and save newly created index
+ let group_idx = self.group_values.num_rows();
+ self.group_values.push(group_rows.row(row));
+
+ // for hasher function, use precomputed hash value
+ self.map.insert_accounted(
+ (hash, group_idx),
+ |(hash, _group_index)| *hash,
+ &mut self.map_size,
+ );
+ group_idx
+ }
+ };
+ groups.push(group_idx);
+ }
+
+ Ok(())
+ }
+
+ fn size(&self) -> usize {
+ self.row_converter.size()
+ + self.group_values.size()
+ + self.map_size
+ + self.hashes_buffer.allocated_size()
+ }
+
+ fn is_empty(&self) -> bool {
+ self.len() == 0
+ }
+
+ fn len(&self) -> usize {
+ self.group_values.num_rows()
+ }
+
+ fn emit(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+ Ok(match emit_to {
+ EmitTo::All => {
+ // Eventually we may also want to clear the hash table here
+ self.row_converter.convert_rows(&self.group_values)?
+ }
+ EmitTo::First(n) => {
+ let groups_rows = self.group_values.iter().take(n);
+ let output = self.row_converter.convert_rows(groups_rows)?;
+ // Clear out first n group keys by copying them to a new Rows.
+ // TODO file some ticket in arrow-rs to make this more
efficient?
+ let mut new_group_values = self.row_converter.empty_rows(0, 0);
+ for row in self.group_values.iter().skip(n) {
+ new_group_values.push(row);
+ }
+ std::mem::swap(&mut new_group_values, &mut self.group_values);
+
+ // SAFETY: self.map outlives iterator and is not modified
concurrently
+ unsafe {
+ for bucket in self.map.iter() {
+ // Decrement group index by n
+ match bucket.as_ref().1.checked_sub(n) {
+ // Group index was >= n, shift value down
+ Some(sub) => bucket.as_mut().1 = sub,
+ // Group index was < n, so remove from table
+ None => self.map.erase(bucket),
+ }
+ }
+ }
+ output
+ }
+ })
+ }
+}
+
/// Hash based Grouping Aggregator
///
/// # Design Goals
@@ -74,29 +249,29 @@ use super::AggregateExec;
///
/// ```text
///
-/// stores "group stores group values, internally stores aggregate
-/// indexes" in arrow_row format values, for all groups
+/// Assigns a consecutive group internally stores aggregate
values
+/// index for each unique set for all groups
+/// of group values
///
-/// ┌─────────────┐ ┌────────────┐ ┌──────────────┐
┌──────────────┐
-/// │ ┌─────┐ │ │ ┌────────┐ │ │┌────────────┐│
│┌────────────┐│
-/// │ │ 5 │ │ ┌────┼▶│ "A" │ │ ││accumulator ││
││accumulator ││
-/// │ ├─────┤ │ │ │ ├────────┤ │ ││ 0 ││ ││ N
││
-/// │ │ 9 │ │ │ │ │ "Z" │ │ ││ ┌────────┐ ││ ││
┌────────┐ ││
-/// │ └─────┘ │ │ │ └────────┘ │ ││ │ state │ ││ ││ │ state
│ ││
-/// │ ... │ │ │ │ ││ │┌─────┐ │ ││ ... ││ │┌─────┐
│ ││
-/// │ ┌─────┐ │ │ │ ... │ ││ │├─────┤ │ ││ ││ │├─────┤
│ ││
-/// │ │ 1 │───┼─┘ │ │ ││ │└─────┘ │ ││ ││ │└─────┘
│ ││
-/// │ ├─────┤ │ │ │ ││ │ │ ││ ││ │
│ ││
-/// │ │ 13 │───┼─┐ │ ┌────────┐ │ ││ │ ... │ ││ ││ │ ...
│ ││
-/// │ └─────┘ │ └────┼▶│ "Q" │ │ ││ │ │ ││ ││ │
│ ││
-/// └─────────────┘ │ └────────┘ │ ││ │┌─────┐ │ ││ ││ │┌─────┐
│ ││
-/// │ │ ││ │└─────┘ │ ││ ││ │└─────┘
│ ││
-/// └────────────┘ ││ └────────┘ ││ ││
└────────┘ ││
-/// │└────────────┘│
│└────────────┘│
-/// └──────────────┘
└──────────────┘
+/// ┌────────────┐ ┌──────────────┐ ┌──────────────┐
+/// │ ┌────────┐ │ │┌────────────┐│ │┌────────────┐│
+/// │ │ "A" │ │ ││accumulator ││ ││accumulator ││
+/// │ ├────────┤ │ ││ 0 ││ ││ N ││
+/// │ │ "Z" │ │ ││ ┌────────┐ ││ ││ ┌────────┐ ││
+/// │ └────────┘ │ ││ │ state │ ││ ││ │ state │ ││
+/// │ │ ││ │┌─────┐ │ ││ ... ││ │┌─────┐ │ ││
+/// │ ... │ ││ │├─────┤ │ ││ ││ │├─────┤ │ ││
+/// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││
+/// │ │ ││ │ │ ││ ││ │ │ ││
+/// │ ┌────────┐ │ ││ │ ... │ ││ ││ │ ... │ ││
+/// │ │ "Q" │ │ ││ │ │ ││ ││ │ │ ││
+/// │ └────────┘ │ ││ │┌─────┐ │ ││ ││ │┌─────┐ │ ││
+/// │ │ ││ │└─────┘ │ ││ ││ │└─────┘ │ ││
+/// └────────────┘ ││ └────────┘ ││ ││ └────────┘ ││
+/// │└────────────┘│ │└────────────┘│
+/// └──────────────┘ └──────────────┘
///
-/// map group_values accumulators
-/// (Hash Table)
+/// group_values accumulators
///
/// ```
///
@@ -108,10 +283,10 @@ use super::AggregateExec;
///
/// # Description
///
-/// The hash table does not store any aggregate state inline. It only
-/// stores "group indices", one for each (distinct) group value. The
+/// [`group_values`] does not store any aggregate state inline. It only
+/// assigns "group indices", one for each (distinct) group value. The
/// accumulators manage the in-progress aggregate state for each
-/// group, and the group values themselves are stored in
+/// group, while the group values themselves are stored in
/// [`group_values`] at the corresponding group index.
///
/// The accumulator state (e.g partial sums) is managed by and stored
@@ -152,40 +327,18 @@ pub(crate) struct GroupedHashAggregateStream {
/// the filter expression is `x > 100`.
filter_expressions: Vec<Option<Arc<dyn PhysicalExpr>>>,
- /// Converter for the group values
- row_converter: RowConverter,
-
/// GROUP BY expressions
group_by: PhysicalGroupBy,
/// The memory reservation for this grouping
reservation: MemoryReservation,
- /// Logically maps group values to a group_index in
- /// [`Self::group_values`] and in each accumulator
- ///
- /// Uses the raw API of hashbrown to avoid actually storing the
- /// keys (group values) in the table
- ///
- /// keys: u64 hashes of the GroupValue
- /// values: (hash, group_index)
- map: RawTable<(u64, usize)>,
-
- /// The actual group by values, stored in arrow [`Row`] format.
- /// `group_values[i]` holds the group value for group_index `i`.
- ///
- /// The row format is used to compare group keys quickly and store
- /// them efficiently in memory. Quick comparison is especially
- /// important for multi-column group keys.
- ///
- /// [`Row`]: arrow::row::Row
- group_values: Rows,
+ /// An interning store of group keys
+ group_values: Box<dyn GroupValues>,
/// scratch space for the current input [`RecordBatch`] being
- /// processed. The reason this is a field is so it can be reused
- /// for all input batches, avoiding the need to reallocate Vecs on
- /// each input.
- scratch_space: ScratchSpace,
+ /// processed. Reused across batches here to avoid reallocations
+ current_group_indices: Vec<usize>,
/// Tracks if this stream is generating input or output
exec_state: ExecutionState,
@@ -193,9 +346,6 @@ pub(crate) struct GroupedHashAggregateStream {
/// Execution metrics
baseline_metrics: BaselineMetrics,
- /// Random state for creating hashes
- random_state: RandomState,
-
/// max rows in output RecordBatches
batch_size: usize,
@@ -252,18 +402,9 @@ impl GroupedHashAggregateStream {
.collect::<Result<_>>()?;
let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
- let row_converter = RowConverter::new(
- group_schema
- .fields()
- .iter()
- .map(|f| SortField::new(f.data_type().clone()))
- .collect(),
- )?;
let name = format!("GroupedHashAggregateStream[{partition}]");
let reservation =
MemoryConsumer::new(name).register(context.memory_pool());
- let map = RawTable::with_capacity(0);
- let group_values = row_converter.empty_rows(0, 0);
let group_ordering = agg
.aggregation_ordering
@@ -275,6 +416,8 @@ impl GroupedHashAggregateStream {
.transpose()?
.unwrap_or(GroupOrdering::None);
+ let group = Box::new(GroupValuesRows::try_new(group_schema)?);
+
timer.done();
let exec_state = ExecutionState::ReadingInput;
@@ -286,15 +429,12 @@ impl GroupedHashAggregateStream {
accumulators,
aggregate_arguments,
filter_expressions,
- row_converter,
group_by: agg_group_by,
reservation,
- map,
- group_values,
- scratch_space: ScratchSpace::new(),
+ group_values: group,
+ current_group_indices: Default::default(),
exec_state,
baseline_metrics,
- random_state: Default::default(),
batch_size,
group_ordering,
input_done: false,
@@ -355,11 +495,9 @@ impl Stream for GroupedHashAggregateStream {
// If we can begin emitting rows, do so,
// otherwise keep consuming input
assert!(!self.input_done);
- let to_emit = self.group_ordering.emit_to();
- if let Some(to_emit) = to_emit {
- let batch =
-
extract_ok!(self.create_batch_from_map(to_emit));
+ if let Some(to_emit) =
self.group_ordering.emit_to() {
+ let batch = extract_ok!(self.emit(to_emit));
self.exec_state =
ExecutionState::ProducingOutput(batch);
}
timer.done();
@@ -373,8 +511,7 @@ impl Stream for GroupedHashAggregateStream {
self.input_done = true;
self.group_ordering.input_done();
let timer = elapsed_compute.timer();
- let batch =
-
extract_ok!(self.create_batch_from_map(EmitTo::All));
+ let batch = extract_ok!(self.emit(EmitTo::All));
self.exec_state =
ExecutionState::ProducingOutput(batch);
timer.done();
}
@@ -415,95 +552,11 @@ impl RecordBatchStream for GroupedHashAggregateStream {
}
impl GroupedHashAggregateStream {
- /// Calculates the group indices for each input row of
- /// `group_values`.
- ///
- /// At the return of this function,
- /// `self.scratch_space.current_group_indices` has the same number
- /// of entries as each array in `group_values` and holds the
- /// correct group_index for that row.
- ///
- /// This is one of the core hot loops in the algorithm
- fn update_group_state(
- &mut self,
- group_values: &[ArrayRef],
- memory_delta: &mut MemoryDelta,
- ) -> Result<()> {
- // Convert the group keys into the row format
- // Avoid reallocation when
https://github.com/apache/arrow-rs/issues/4479 is available
- let group_rows = self.row_converter.convert_columns(group_values)?;
- let n_rows = group_rows.num_rows();
-
- // track memory used
- memory_delta.dec(self.state_size());
-
- // tracks to which group each of the input rows belongs
- let group_indices = &mut self.scratch_space.current_group_indices;
- group_indices.clear();
-
- // 1.1 Calculate the group keys for the group values
- let batch_hashes = &mut self.scratch_space.hashes_buffer;
- batch_hashes.clear();
- batch_hashes.resize(n_rows, 0);
- create_hashes(group_values, &self.random_state, batch_hashes)?;
-
- let mut allocated = 0;
- let starting_num_groups = self.group_values.num_rows();
- for (row, &hash) in batch_hashes.iter().enumerate() {
- let entry = self.map.get_mut(hash, |(_hash, group_idx)| {
- // verify that a group that we are inserting with hash is
- // actually the same key value as the group in
- // existing_idx (aka group_values @ row)
- group_rows.row(row) == self.group_values.row(*group_idx)
- });
-
- let group_idx = match entry {
- // Existing group_index for this group value
- Some((_hash, group_idx)) => *group_idx,
- // 1.2 Need to create new entry for the group
- None => {
- // Add new entry to aggr_state and save newly created index
- let group_idx = self.group_values.num_rows();
- self.group_values.push(group_rows.row(row));
-
- // for hasher function, use precomputed hash value
- self.map.insert_accounted(
- (hash, group_idx),
- |(hash, _group_index)| *hash,
- &mut allocated,
- );
- group_idx
- }
- };
- group_indices.push(group_idx);
- }
- memory_delta.inc(allocated);
-
- // Update ordering information if necessary
- let total_num_groups = self.group_values.num_rows();
- if total_num_groups > starting_num_groups {
- self.group_ordering.new_groups(
- group_values,
- group_indices,
- total_num_groups,
- )?;
- }
-
- // account for memory change
- memory_delta.inc(self.state_size());
-
- Ok(())
- }
-
/// Perform group-by aggregation for the given [`RecordBatch`].
fn group_aggregate_batch(&mut self, batch: RecordBatch) -> Result<()> {
// Evaluate the grouping expressions
let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
- // Keep track of memory allocated:
- let mut memory_delta = MemoryDelta::new();
- memory_delta.dec(self.state_size());
-
// Evaluate the aggregation expressions.
let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;
@@ -512,8 +565,20 @@ impl GroupedHashAggregateStream {
for group_values in &group_by_values {
// calculate the group indices for each input row
- self.update_group_state(group_values, &mut memory_delta)?;
- let group_indices = &self.scratch_space.current_group_indices;
+ let starting_num_groups = self.group_values.len();
+ self.group_values
+ .intern(group_values, &mut self.current_group_indices)?;
+ let group_indices = &self.current_group_indices;
+
+ // Update ordering information if necessary
+ let total_num_groups = self.group_values.len();
+ if total_num_groups > starting_num_groups {
+ self.group_ordering.new_groups(
+ group_values,
+ group_indices,
+ total_num_groups,
+ )?;
+ }
// Gather the inputs to call the actual accumulator
let t = self
@@ -522,10 +587,7 @@ impl GroupedHashAggregateStream {
.zip(input_values.iter())
.zip(filter_values.iter());
- let total_num_groups = self.group_values.num_rows();
-
for ((acc, values), opt_filter) in t {
- memory_delta.dec(acc.size());
let opt_filter = opt_filter.as_ref().map(|filter|
filter.as_boolean());
// Call the appropriate method on each aggregator with
@@ -552,42 +614,32 @@ impl GroupedHashAggregateStream {
)?;
}
}
- memory_delta.inc(acc.size());
}
}
- memory_delta.inc(self.state_size());
- // Update allocation AFTER it is used, simplifying accounting,
- // though it results in a temporary overshoot.
- memory_delta.update(&mut self.reservation)
+ self.update_memory_reservation()
+ }
+
+ fn update_memory_reservation(&mut self) -> Result<()> {
+ let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
+ self.reservation.try_resize(
+ acc + self.group_values.size()
+ + self.group_ordering.size()
+ + self.current_group_indices.allocated_size(),
+ )
}
/// Create an output RecordBatch with the group keys and
/// accumulator states/values specified in emit_to
- fn create_batch_from_map(&mut self, emit_to: EmitTo) ->
Result<RecordBatch> {
- if self.group_values.num_rows() == 0 {
+ fn emit(&mut self, emit_to: EmitTo) -> Result<RecordBatch> {
+ if self.group_values.is_empty() {
return Ok(RecordBatch::new_empty(self.schema()));
}
- let output = self.build_output(emit_to)?;
- self.remove_emitted(emit_to)?;
- let batch = RecordBatch::try_new(self.schema(), output)?;
- Ok(batch)
- }
-
- /// Creates output: `(group 1, group 2, ... agg 1, agg 2, ...)`
- fn build_output(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
- // First output rows are the groups
- let mut output: Vec<ArrayRef> = match emit_to {
- EmitTo::All => {
- let groups_rows = self.group_values.iter();
- self.row_converter.convert_rows(groups_rows)?
- }
- EmitTo::First(n) => {
- let groups_rows = self.group_values.iter().take(n);
- self.row_converter.convert_rows(groups_rows)?
- }
- };
+ let mut output = self.group_values.emit(emit_to)?;
+ if let EmitTo::First(n) = emit_to {
+ self.group_ordering.remove_groups(n);
+ }
// Next output each aggregate value
for acc in self.accumulators.iter_mut() {
@@ -600,78 +652,8 @@ impl GroupedHashAggregateStream {
}
}
- Ok(output)
- }
-
- /// Removes the first `n` groups, adjusting all group_indices
- /// appropriately
- fn remove_emitted(&mut self, emit_to: EmitTo) -> Result<()> {
- let mut memory_delta = MemoryDelta::new();
- memory_delta.dec(self.state_size());
-
- match emit_to {
- EmitTo::All => {
- // Eventually we may also want to clear the hash table here
- //self.map.clear();
- }
- EmitTo::First(n) => {
- // Clear out first n group keys by copying them to a new Rows.
- // TODO file some ticket in arrow-rs to make this more
efficent?
- let mut new_group_values = self.row_converter.empty_rows(0, 0);
- for row in self.group_values.iter().skip(n) {
- new_group_values.push(row);
- }
- std::mem::swap(&mut new_group_values, &mut self.group_values);
-
- self.group_ordering.remove_groups(n);
- // SAFETY: self.map outlives iterator and is not modified
concurrently
- unsafe {
- for bucket in self.map.iter() {
- // Decrement group index by n
- match bucket.as_ref().1.checked_sub(n) {
- // Group index was >= n, shift value down
- Some(sub) => bucket.as_mut().1 = sub,
- // Group index was < n, so remove from table
- None => self.map.erase(bucket),
- }
- }
- }
- }
- };
- // account for memory change
- memory_delta.inc(self.state_size());
- memory_delta.update(&mut self.reservation)
- }
-
- /// return the current size stored by variable state in this structure
- fn state_size(&self) -> usize {
- self.group_values.size()
- + self.scratch_space.size()
- + self.group_ordering.size()
- + self.row_converter.size()
- }
-}
-
-/// Holds structures used for the current input [`RecordBatch`] being
-/// processed. Reused across batches here to avoid reallocations
-#[derive(Debug, Default)]
-struct ScratchSpace {
- /// scratch space for the current input [`RecordBatch`] being
- /// processed. Reused across batches here to avoid reallocations
- current_group_indices: Vec<usize>,
- // buffer to be reused to store hashes
- hashes_buffer: Vec<u64>,
-}
-
-impl ScratchSpace {
- fn new() -> Self {
- Default::default()
- }
-
- /// Return the amount of memory alocated by this structure in bytes
- fn size(&self) -> usize {
- std::mem::size_of_val(self)
- + self.current_group_indices.allocated_size()
- + self.hashes_buffer.allocated_size()
+ self.update_memory_reservation()?;
+ let batch = RecordBatch::try_new(self.schema(), output)?;
+ Ok(batch)
}
}
diff --git a/datafusion/execution/src/memory_pool/mod.rs
b/datafusion/execution/src/memory_pool/mod.rs
index fe077524a4..011cd72cbb 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -219,49 +219,6 @@ pub fn human_readable_size(size: usize) -> String {
format!("{value:.1} {unit}")
}
-/// Tracks the change in memory to avoid overflow. Typically, this
-/// is isued like the following
-///
-/// 1. Call `delta.dec(sized_thing.size())`
-///
-/// 2. potentially change size of `sized_thing`
-///
-/// 3. Call `delta.inc(size_thing.size())`
-#[derive(Debug, Default)]
-pub struct MemoryDelta {
- decrease: usize,
- increase: usize,
-}
-
-impl MemoryDelta {
- pub fn new() -> Self {
- Default::default()
- }
-
- /// record size being 'decremented'. This is used for to record the
- /// initial size of some allocation prior to hange
- pub fn dec(&mut self, sz: usize) {
- self.decrease += sz;
- }
-
- /// record size being 'incremented'. This is used for to record
- /// the final size of some object.
- pub fn inc(&mut self, sz: usize) {
- self.increase += sz;
- }
-
- /// Adjusts the reservation with the delta used / freed
- pub fn update(self, reservation: &mut MemoryReservation) -> Result<()> {
- let Self { decrease, increase } = self;
- match increase.cmp(&decrease) {
- Ordering::Less => reservation.shrink(decrease - increase),
- Ordering::Equal => {}
- Ordering::Greater => reservation.try_grow(increase - decrease)?,
- };
- Ok(())
- }
-}
-
#[cfg(test)]
mod tests {
use super::*;