This is an automated email from the ASF dual-hosted git repository.
akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new ca60ff13ad Streaming Memory Reservation in SHJ (#5937)
ca60ff13ad is described below
commit ca60ff13ad88ede7ea041fb1fe4cfdf0144f9cb4
Author: Metehan Yıldırım <[email protected]>
AuthorDate: Fri Apr 14 00:14:58 2023 +0300
Streaming Memory Reservation in SHJ (#5937)
* Memory mngr
* Minor changes on merge leftovers
* Added try_resize and additional memory measurements
* Update cp_solver.rs
* Memory testing
---
datafusion/core/src/datasource/streaming.rs | 13 ++-
.../core/src/physical_plan/joins/hash_join.rs | 23 +----
.../src/physical_plan/joins/hash_join_utils.rs | 90 ++++++++++++++++-
.../src/physical_plan/joins/symmetric_hash_join.rs | 66 ++++++++++++-
datafusion/core/src/physical_plan/streaming.rs | 7 ++
datafusion/core/tests/memory_limit.rs | 108 +++++++++++++++++++++
datafusion/execution/src/memory_pool/mod.rs | 11 +++
.../physical-expr/src/intervals/cp_solver.rs | 13 +++
8 files changed, 304 insertions(+), 27 deletions(-)
diff --git a/datafusion/core/src/datasource/streaming.rs
b/datafusion/core/src/datasource/streaming.rs
index b3dc60c8c8..4a234fbe13 100644
--- a/datafusion/core/src/datasource/streaming.rs
+++ b/datafusion/core/src/datasource/streaming.rs
@@ -44,6 +44,7 @@ pub trait PartitionStream: Send + Sync {
pub struct StreamingTable {
schema: SchemaRef,
partitions: Vec<Arc<dyn PartitionStream>>,
+ /// Whether scans of this table are reported as unbounded (infinite).
+ /// `false` unless changed with `with_infinite_table`; forwarded to the
+ /// execution plan created when the table is scanned.
+ infinite: bool,
}
impl StreamingTable {
@@ -58,7 +59,16 @@ impl StreamingTable {
));
}
- Ok(Self { schema, partitions })
+ Ok(Self {
+ schema,
+ partitions,
+ infinite: false,
+ })
+ }
+ /// Sets streaming table can be infinite.
+ pub fn with_infinite_table(mut self, infinite: bool) -> Self {
+ self.infinite = infinite;
+ self
}
}
@@ -88,6 +98,7 @@ impl TableProvider for StreamingTable {
self.schema.clone(),
self.partitions.clone(),
projection,
+ self.infinite,
)?))
}
}
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs
b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 46f2d92903..c624c5c17d 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -44,7 +44,7 @@ use arrow::{
};
use futures::{ready, Stream, StreamExt, TryStreamExt};
use hashbrown::raw::RawTable;
-use smallvec::{smallvec, SmallVec};
+use smallvec::smallvec;
use std::fmt;
use std::sync::Arc;
use std::task::Poll;
@@ -87,26 +87,7 @@ use super::{
utils::{OnceAsync, OnceFut},
PartitionMode,
};
-
-// Maps a `u64` hash value based on the build side ["on" values] to a list of
indices with this key's value.
-//
-// Note that the `u64` keys are not stored in the hashmap (hence the `()` as
key), but are only used
-// to put the indices in a certain bucket.
-// By allocating a `HashMap` with capacity for *at least* the number of rows
for entries at the build side,
-// we make sure that we don't have to re-hash the hashmap, which needs access
to the key (the hash in this case) value.
-// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8
for hash value 1
-// As the key is a hash value, we need to check possible hash collisions in
the probe stage
-// During this stage it might be the case that a row is contained the same
hashmap value,
-// but the values don't match. Those are checked in the [equal_rows] macro
-// TODO: speed up collision check and move away from using a hashbrown HashMap
-// https://github.com/apache/arrow-datafusion/issues/50
-pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>);
-
-impl fmt::Debug for JoinHashMap {
- fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
- Ok(())
- }
-}
+use crate::physical_plan::joins::hash_join_utils::JoinHashMap;
type JoinLeftData = (JoinHashMap, RecordBatch);
diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index ffba78c1d3..f411370ef0 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -20,7 +20,7 @@
use std::collections::HashMap;
use std::sync::Arc;
-use std::usize;
+use std::{fmt, usize};
use arrow::datatypes::SchemaRef;
@@ -29,10 +29,57 @@ use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::intervals::Interval;
use datafusion_physical_expr::utils::collect_columns;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use hashbrown::raw::RawTable;
+use smallvec::SmallVec;
use crate::common::Result;
use crate::physical_plan::joins::utils::{JoinFilter, JoinSide};
+/// Maps a `u64` hash of the build side's "on" (join key) values to the list of
+/// build-side row indices that produced that hash, e.g. `(1, [3, 6, 8])` means
+/// rows 3, 6 and 8 all hashed to `1`.
+///
+/// The hash is stored in each entry alongside its indices. The hasher closures
+/// passed to `RawTable` operations (e.g. `insert`, `shrink_to`) read that
+/// stored hash back, so the table can be grown or shrunk without recomputing
+/// hashes from the original column values.
+///
+/// Because the key is only a hash, distinct key values can collide. Matches
+/// must therefore be confirmed by comparing the actual column values in the
+/// probe stage (see `equal_rows`; NOTE(review): confirm the current helper name).
+///
+/// TODO: speed up collision check and move away from using a hashbrown HashMap
+/// <https://github.com/apache/arrow-datafusion/issues/50>
+pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>);
+
+impl JoinHashMap {
+    /// Shrinks the underlying table once it has become sparse.
+    ///
+    /// When the current capacity is more than `scale_factor` times the number
+    /// of stored entries, the table is shrunk towards
+    /// `capacity * (scale_factor - 1) / scale_factor` (a 25% reduction for
+    /// `scale_factor == 4`). `RawTable::shrink_to` never drops entries, so the
+    /// resulting capacity still fits every stored entry. Each entry's stored
+    /// hash is reused for rehashing.
+    ///
+    /// A larger `scale_factor` shrinks less eagerly: potentially more memory,
+    /// fewer resizes. A smaller one shrinks more eagerly: less memory, more
+    /// frequent resizes.
+    ///
+    /// A `scale_factor` of `0` disables shrinking. The target formula is
+    /// undefined for it: `scale_factor - 1` would underflow and the division
+    /// would be by zero.
+    pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) {
+        if scale_factor == 0 {
+            return;
+        }
+        let capacity = self.0.capacity();
+        let len = self.0.len();
+
+        // Saturate so the sparsity check cannot overflow for huge inputs.
+        if capacity > scale_factor.saturating_mul(len) {
+            let new_capacity = (capacity * (scale_factor - 1)) / scale_factor;
+            self.0.shrink_to(new_capacity, |(hash, _)| *hash)
+        }
+    }
+
+    /// Returns the size, in bytes, of the table's allocation as reported by
+    /// `RawTable::allocation_info`.
+    ///
+    /// The inline part of every `SmallVec` is included. Indices that a
+    /// `SmallVec` has spilled to the heap are not.
+    pub(crate) fn size(&self) -> usize {
+        self.0.allocation_info().1.size()
+    }
+}
+
+impl fmt::Debug for JoinHashMap {
+    /// Deliberately renders nothing; the table's contents are not printed.
+    fn fmt(&self, _: &mut fmt::Formatter<'_>) -> fmt::Result {
+        Ok(())
+    }
+}
+
fn check_filter_expr_contains_sort_information(
expr: &Arc<dyn PhysicalExpr>,
reference: &Arc<dyn PhysicalExpr>,
@@ -243,6 +290,7 @@ pub mod tests {
use datafusion_common::ScalarValue;
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::{binary, cast, col, lit};
+ use smallvec::smallvec;
use std::sync::Arc;
/// Filter expr for a + b > c + 10 AND a + b < c + 100
@@ -576,4 +624,44 @@ pub mod tests {
assert!(res.is_none());
Ok(())
}
+
+    #[test]
+    fn test_shrink_if_necessary() {
+        let scale_factor = 4;
+        let mut join_hash_map = JoinHashMap(RawTable::with_capacity(100));
+        let data_size: u64 = 2000;
+        let deleted_part = 3 * data_size / 4;
+        // Add elements to the JoinHashMap, using each value as its own hash.
+        for hash_value in 0..data_size {
+            join_hash_map.0.insert(
+                hash_value,
+                (hash_value, smallvec![hash_value]),
+                |(hash, _)| *hash,
+            );
+        }
+
+        assert_eq!(join_hash_map.0.len(), data_size as usize);
+        assert!(join_hash_map.0.capacity() >= data_size as usize);
+
+        // Remove the first three quarters of the elements.
+        for hash_value in 0..deleted_part {
+            join_hash_map
+                .0
+                .remove_entry(hash_value, |(hash, _)| hash_value == *hash);
+        }
+
+        let remaining = (data_size - deleted_part) as usize;
+        assert_eq!(join_hash_map.0.len(), remaining);
+
+        let old_capacity = join_hash_map.0.capacity();
+
+        join_hash_map.shrink_if_necessary(scale_factor);
+
+        // Shrinking must not grow the table. It must not go below the target
+        // `old_capacity * (scale_factor - 1) / scale_factor`, and it must keep
+        // room for every remaining entry. The target is computed from the
+        // capacity *before* shrinking; computing it from the post-shrink
+        // capacity would make the assertion trivially true.
+        let new_capacity = join_hash_map.0.capacity();
+        let min_expected_capacity = old_capacity * (scale_factor - 1) / scale_factor;
+        assert!(new_capacity <= old_capacity);
+        assert!(new_capacity >= min_expected_capacity);
+        assert!(new_capacity >= remaining);
+
+        // Every entry that was not deleted must survive the rehash.
+        assert_eq!(join_hash_map.0.len(), remaining);
+        for hash_value in deleted_part..data_size {
+            assert!(join_hash_map
+                .0
+                .find(hash_value, |(hash, _)| *hash == hash_value)
+                .is_some());
+        }
+    }
}
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index f5d06a8c0a..47c511a1a6 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -47,17 +47,20 @@ use hashbrown::{raw::RawTable, HashSet};
use parking_lot::Mutex;
use datafusion_common::{utils::bisect, ScalarValue};
+use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval};
use crate::error::{DataFusionError, Result};
use crate::execution::context::TaskContext;
use crate::logical_expr::JoinType;
+use crate::physical_plan::common::SharedMemoryReservation;
use
crate::physical_plan::joins::hash_join_utils::convert_sort_expr_with_filter_schema;
+use crate::physical_plan::joins::hash_join_utils::JoinHashMap;
use crate::physical_plan::{
expressions::Column,
expressions::PhysicalSortExpr,
joins::{
- hash_join::{build_join_indices, update_hash, JoinHashMap},
+ hash_join::{build_join_indices, update_hash},
hash_join_utils::{build_filter_input_order, SortedFilterExpr},
utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
@@ -70,6 +73,8 @@ use crate::physical_plan::{
RecordBatchStream, SendableRecordBatchStream, Statistics,
};
+/// Scale factor passed to `JoinHashMap::shrink_if_necessary` after pruning a
+/// side's hash table. The table is shrunk once its capacity exceeds this many
+/// times its number of entries.
+const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
+
/// A symmetric hash join with range conditions is when both streams are
hashed on the
/// join key and the resulting hash tables are used to join the streams.
/// The join is considered symmetric because the hash table is built on the
join keys from both
@@ -209,6 +214,8 @@ struct SymmetricHashJoinMetrics {
left: SymmetricHashJoinSideMetrics,
/// Number of right batches/rows consumed by this operator
right: SymmetricHashJoinSideMetrics,
+ /// Memory used by sides in bytes
+ pub(crate) stream_memory_usage: metrics::Gauge,
/// Number of batches produced by this operator
output_batches: metrics::Count,
/// Number of rows produced by this operator
@@ -233,6 +240,9 @@ impl SymmetricHashJoinMetrics {
input_rows,
};
+ let stream_memory_usage =
+ MetricBuilder::new(metrics).gauge("stream_memory_usage",
partition);
+
let output_batches =
MetricBuilder::new(metrics).counter("output_batches", partition);
@@ -242,6 +252,7 @@ impl SymmetricHashJoinMetrics {
left,
right,
output_batches,
+ stream_memory_usage,
output_rows,
}
}
@@ -581,7 +592,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
let right_stream = self
.right
- .execute(partition, context)?
+ .execute(partition, context.clone())?
.map(|val| (JoinSide::Right, val));
// This function will attempt to pull items from both streams.
// Each stream will be polled in a round-robin fashion, and whenever a
stream is
@@ -590,6 +601,14 @@ impl ExecutionPlan for SymmetricHashJoinExec {
// The returned stream completes when both input streams have
completed.
let input_stream = select(left_stream, right_stream).boxed();
+ let reservation = Arc::new(Mutex::new(
+
MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
+ .register(context.memory_pool()),
+ ));
+ if let Some(g) = graph.as_ref() {
+ reservation.lock().try_grow(g.size())?;
+ }
+
Ok(Box::pin(SymmetricHashJoinStream {
input_stream,
schema: self.schema(),
@@ -605,6 +624,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
right_sorted_filter_expr,
null_equals_null: self.null_equals_null,
final_result: false,
+ reservation,
}))
}
}
@@ -637,6 +657,8 @@ struct SymmetricHashJoinStream {
null_equals_null: bool,
/// Metrics
metrics: SymmetricHashJoinMetrics,
+ /// Memory reservation
+ reservation: SharedMemoryReservation,
/// Flag indicating whether there is nothing to process anymore
final_result: bool,
}
@@ -689,6 +711,7 @@ fn prune_hash_values(
}
}
}
+ hashmap.shrink_if_necessary(HASHMAP_SHRINK_SCALE_FACTOR);
Ok(())
}
@@ -1041,12 +1064,26 @@ struct OneSideHashJoiner {
}
impl OneSideHashJoiner {
+    /// Estimates the memory held by this joiner, in bytes.
+    ///
+    /// `size_of_val(self)` already accounts for every field stored inline in
+    /// the struct, e.g. `build_side`, `offset`, `deleted_offset` and the `Vec`
+    /// header of `on`. Only heap allocations owned by the fields are added on
+    /// top; adding the inline fields again would double-count them against the
+    /// memory pool.
+    ///
+    /// Heap usage is based on allocated capacity. Per-container bookkeeping
+    /// (such as hash-set control bytes) is not included, so this is an estimate.
+    pub fn size(&self) -> usize {
+        let mut size = std::mem::size_of_val(self);
+        size += self.input_buffer.get_array_memory_size();
+        size += self.on.capacity() * std::mem::size_of::<Column>();
+        size += self.hashmap.size();
+        size += self.row_hash_values.capacity() * std::mem::size_of::<u64>();
+        size += self.hashes_buffer.capacity() * std::mem::size_of::<u64>();
+        size += self.visited_rows.capacity() * std::mem::size_of::<usize>();
+        size
+    }
pub fn new(build_side: JoinSide, on: Vec<Column>, schema: SchemaRef) ->
Self {
Self {
build_side,
input_buffer: RecordBatch::new_empty(schema),
on,
- hashmap: JoinHashMap(RawTable::with_capacity(10_000)),
+ hashmap: JoinHashMap(RawTable::with_capacity(0)),
row_hash_values: VecDeque::new(),
hashes_buffer: vec![],
visited_rows: HashSet::new(),
@@ -1074,6 +1111,7 @@ impl OneSideHashJoiner {
self.input_buffer = concat_batches(&batch.schema(),
[&self.input_buffer, batch])?;
// Resize the hashes buffer to the number of rows in the incoming
batch:
self.hashes_buffer.resize(batch.num_rows(), 0);
+ // Get allocation_info before adding the item
// Update the hashmap with the join key values and hashes of the
incoming batch:
update_hash(
&self.on,
@@ -1339,6 +1377,24 @@ fn combine_two_batches(
}
impl SymmetricHashJoinStream {
+    /// Estimates the memory held by this stream, in bytes.
+    ///
+    /// The inline sizes of the plain fields are summed. The two one-sided
+    /// joiners and the optional interval graph add their own estimates.
+    fn size(&self) -> usize {
+        use std::mem::size_of_val;
+
+        let plain_fields = size_of_val(&self.input_stream)
+            + size_of_val(&self.schema)
+            + size_of_val(&self.filter)
+            + size_of_val(&self.join_type)
+            + size_of_val(&self.column_indices)
+            + size_of_val(&self.left_sorted_filter_expr)
+            + size_of_val(&self.right_sorted_filter_expr)
+            + size_of_val(&self.random_state)
+            + size_of_val(&self.null_equals_null)
+            + size_of_val(&self.metrics)
+            + size_of_val(&self.final_result);
+        let joiners = self.left.size() + self.right.size();
+        let graph = self.graph.as_ref().map_or(0, |g| g.size());
+
+        plain_fields + joiners + graph
+    }
/// Polls the next result of the join operation.
///
/// If the result of the join is ready, it returns the next record batch.
@@ -1442,6 +1498,9 @@ impl SymmetricHashJoinStream {
// Combine results:
let result =
combine_two_batches(&self.schema, equal_result,
anti_result)?;
+ let capacity = self.size();
+ self.metrics.stream_memory_usage.set(capacity);
+ self.reservation.lock().try_resize(capacity)?;
// Update the metrics if we have a batch; otherwise,
continue the loop.
if let Some(batch) = &result {
self.metrics.output_batches.add(1);
@@ -1482,7 +1541,6 @@ impl SymmetricHashJoinStream {
// Update the metrics:
self.metrics.output_batches.add(1);
self.metrics.output_rows.add(batch.num_rows());
-
return Poll::Ready(Ok(result).transpose());
}
}
diff --git a/datafusion/core/src/physical_plan/streaming.rs
b/datafusion/core/src/physical_plan/streaming.rs
index efd43aca6b..ff5d88dd31 100644
--- a/datafusion/core/src/physical_plan/streaming.rs
+++ b/datafusion/core/src/physical_plan/streaming.rs
@@ -37,6 +37,7 @@ pub struct StreamingTableExec {
partitions: Vec<Arc<dyn PartitionStream>>,
projection: Option<Arc<[usize]>>,
projected_schema: SchemaRef,
+ infinite: bool,
}
impl StreamingTableExec {
@@ -45,6 +46,7 @@ impl StreamingTableExec {
schema: SchemaRef,
partitions: Vec<Arc<dyn PartitionStream>>,
projection: Option<&Vec<usize>>,
+ infinite: bool,
) -> Result<Self> {
if !partitions.iter().all(|x| schema.contains(x.schema())) {
return Err(DataFusionError::Plan(
@@ -61,6 +63,7 @@ impl StreamingTableExec {
partitions,
projected_schema,
projection: projection.cloned().map(Into::into),
+ infinite,
})
}
}
@@ -85,6 +88,10 @@ impl ExecutionPlan for StreamingTableExec {
Partitioning::UnknownPartitioning(self.partitions.len())
}
+    /// Reports whether this scan is unbounded. This is the `infinite` flag
+    /// supplied at construction; children are not consulted.
+    fn unbounded_output(&self, _children: &[bool]) -> Result<bool> {
+        let infinite = self.infinite;
+        Ok(infinite)
+    }
+
fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
None
}
diff --git a/datafusion/core/tests/memory_limit.rs
b/datafusion/core/tests/memory_limit.rs
index 13b6291f9c..71f7f30bf0 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -17,15 +17,23 @@
//! This module contains tests for limiting memory at runtime in DataFusion
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use futures::StreamExt;
use std::sync::Arc;
+use datafusion::datasource::streaming::{PartitionStream, StreamingTable};
use datafusion::datasource::MemTable;
use datafusion::execution::context::SessionState;
use datafusion::execution::disk_manager::DiskManagerConfig;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::physical_optimizer::pipeline_fixer::PipelineFixer;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::SendableRecordBatchStream;
use datafusion_common::assert_contains;
use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_execution::TaskContext;
use test_utils::AccessLogGenerator;
#[cfg(test)]
@@ -161,6 +169,23 @@ async fn merge_join() {
.await
}
+#[tokio::test]
+async fn test_limit_symmetric_hash_join() {
+    // Self-join of the unbounded table `t` on (pod, time). With a 1_000 byte
+    // memory limit, the join's memory reservation is expected to fail.
+    let query = "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time";
+    let expected_errors = vec![
+        "Resources exhausted: Failed to allocate additional",
+        "SymmetricHashJoinStream",
+    ];
+    let memory_limit = 1_000;
+
+    run_streaming_test_with_config(query, expected_errors, memory_limit, SessionConfig::new())
+        .await
+}
+
+/// Memory fraction passed with the memory limit to
+/// `RuntimeConfig::with_memory_limit` in these tests.
const MEMORY_FRACTION: f64 = 0.95;
/// runs the specified query against 1000 rows with specified
@@ -219,3 +244,86 @@ async fn run_limit_test_with_config(
}
}
}
+
+/// A `PartitionStream` that replays a fixed list of record batches. It backs
+/// the unbounded `StreamingTable` used by the streaming memory-limit tests.
+struct DummyStreamPartition {
+ schema: SchemaRef,
+ batches: Vec<RecordBatch>,
+}
+
+impl PartitionStream for DummyStreamPartition {
+    fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+
+    /// Replays a copy of the stored batches, in order, as a stream of `Ok`
+    /// items.
+    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
+        let items = self.batches.clone().into_iter().map(Ok);
+        let stream = futures::stream::iter(items);
+        Box::pin(RecordBatchStreamAdapter::new(self.schema.clone(), stream))
+    }
+}
+
+/// Runs `query` against an unbounded `StreamingTable` registered as `t` and
+/// asserts that execution fails with an error containing every substring in
+/// `expected_error_contains`.
+///
+/// The table replays 1000 generated access-log rows in batches of at most 50
+/// rows. Spilling to disk is disabled. The runtime memory limit is
+/// `memory_limit` bytes, with `MEMORY_FRACTION` as the memory fraction. Only
+/// the `PipelineFixer` physical optimizer rule is enabled.
+///
+/// # Panics
+///
+/// Panics if the query succeeds, if the error lacks an expected substring, or
+/// if any setup step (table creation, runtime creation, planning) fails.
+async fn run_streaming_test_with_config(
+    query: &str,
+    expected_error_contains: Vec<&str>,
+    memory_limit: usize,
+    config: SessionConfig,
+) {
+    // Generate 1000 access-log rows, split into batches of at most 50 rows.
+    let batches: Vec<_> = AccessLogGenerator::new()
+        .with_row_limit(1000)
+        .with_max_batch_size(50)
+        .collect();
+    let schema = batches
+        .first()
+        .expect("access log generator produced no batches")
+        .schema();
+
+    // A single-partition streaming table that replays the generated batches,
+    // marked as unbounded (infinite). The batches are moved, not cloned.
+    let table = StreamingTable::try_new(
+        schema.clone(),
+        vec![Arc::new(DummyStreamPartition { schema, batches })],
+    )
+    .unwrap()
+    .with_infinite_table(true);
+
+    // Disable the disk manager to disallow spilling, and cap memory at
+    // `memory_limit` bytes (scaled by MEMORY_FRACTION).
+    let rt_config = RuntimeConfig::new()
+        .with_disk_manager(DiskManagerConfig::Disabled)
+        .with_memory_limit(memory_limit, MEMORY_FRACTION);
+    let runtime = RuntimeEnv::new(rt_config).unwrap();
+
+    // Disable all physical optimizer rules except the PipelineFixer rule to
+    // avoid sorts or repartitions, as they also have memory budgets that may
+    // be hit first.
+    let state = SessionState::with_config_rt(config, Arc::new(runtime))
+        .with_physical_optimizer_rules(vec![Arc::new(PipelineFixer::new())]);
+
+    let ctx = SessionContext::with_state(state);
+    ctx.register_table("t", Arc::new(table))
+        .expect("registering table");
+
+    let df = ctx.sql(query).await.expect("Planning query");
+
+    // Execution is expected to hit the memory limit; success is a test failure.
+    match df.collect().await {
+        Ok(_batches) => {
+            panic!("Unexpected success when running, expected memory limit failure")
+        }
+        Err(e) => {
+            for error_substring in expected_error_contains {
+                assert_contains!(e.to_string(), error_substring);
+            }
+        }
+    }
+}
diff --git a/datafusion/execution/src/memory_pool/mod.rs
b/datafusion/execution/src/memory_pool/mod.rs
index f1f745d4eb..066cec9fb1 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -143,6 +143,17 @@ impl MemoryReservation {
}
}
+    /// Try to set the size of this reservation to exactly `capacity` bytes.
+    ///
+    /// If `capacity` is larger than the current size, the reservation grows
+    /// with [`Self::try_grow`] and any error from it is returned. If
+    /// `capacity` is smaller, it shrinks with [`Self::shrink`]. Nothing
+    /// happens when the sizes already match.
+    pub fn try_resize(&mut self, capacity: usize) -> Result<()> {
+        let current = self.size;
+        if capacity > current {
+            self.try_grow(capacity - current)?;
+        } else if capacity < current {
+            self.shrink(current - capacity);
+        }
+        Ok(())
+    }
+
/// Increase the size of this reservation by `capacity` bytes
pub fn grow(&mut self, capacity: usize) {
self.policy.grow(self, capacity);
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs
b/datafusion/physical-expr/src/intervals/cp_solver.rs
index a54bcb2f35..5e9353599e 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -122,6 +122,19 @@ pub struct ExprIntervalGraph {
root: NodeIndex,
}
+impl ExprIntervalGraph {
+    /// Estimates the memory footprint of this graph in bytes, including `Self`.
+    ///
+    /// Each node is charged one `ExprIntervalGraphNode` plus one `NodeIndex`.
+    /// Each edge is charged one `usize` plus two `NodeIndex` values. Heap
+    /// memory owned by node payloads is not included.
+    pub fn size(&self) -> usize {
+        use std::mem::{size_of, size_of_val};
+
+        let bytes_per_node = size_of::<ExprIntervalGraphNode>() + size_of::<NodeIndex>();
+        let bytes_per_edge = size_of::<usize>() + 2 * size_of::<NodeIndex>();
+
+        size_of_val(self)
+            + self.graph.node_count() * bytes_per_node
+            + self.graph.edge_count() * bytes_per_edge
+    }
+}
+
/// This object encapsulates all possible constraint propagation results.
#[derive(PartialEq, Debug)]
pub enum PropagationResult {