This is an automated email from the ASF dual-hosted git repository.

akurmustafa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new ca60ff13ad Streaming Memory Reservation in SHJ (#5937)
ca60ff13ad is described below

commit ca60ff13ad88ede7ea041fb1fe4cfdf0144f9cb4
Author: Metehan Yıldırım <100111937+metesynn...@users.noreply.github.com>
AuthorDate: Fri Apr 14 00:14:58 2023 +0300

    Streaming Memory Reservation in SHJ (#5937)
    
    * Memory mngr
    
    * Minor changes on merge leftovers
    
    * Added try_resize and additional memory measurements
    
    * Update cp_solver.rs
    
    * Memory testing
---
 datafusion/core/src/datasource/streaming.rs        |  13 ++-
 .../core/src/physical_plan/joins/hash_join.rs      |  23 +----
 .../src/physical_plan/joins/hash_join_utils.rs     |  90 ++++++++++++++++-
 .../src/physical_plan/joins/symmetric_hash_join.rs |  66 ++++++++++++-
 datafusion/core/src/physical_plan/streaming.rs     |   7 ++
 datafusion/core/tests/memory_limit.rs              | 108 +++++++++++++++++++++
 datafusion/execution/src/memory_pool/mod.rs        |  11 +++
 .../physical-expr/src/intervals/cp_solver.rs       |  13 +++
 8 files changed, 304 insertions(+), 27 deletions(-)

diff --git a/datafusion/core/src/datasource/streaming.rs 
b/datafusion/core/src/datasource/streaming.rs
index b3dc60c8c8..4a234fbe13 100644
--- a/datafusion/core/src/datasource/streaming.rs
+++ b/datafusion/core/src/datasource/streaming.rs
@@ -44,6 +44,7 @@ pub trait PartitionStream: Send + Sync {
 pub struct StreamingTable {
     schema: SchemaRef,
     partitions: Vec<Arc<dyn PartitionStream>>,
+    infinite: bool,
 }
 
 impl StreamingTable {
@@ -58,7 +59,16 @@ impl StreamingTable {
             ));
         }
 
-        Ok(Self { schema, partitions })
+        Ok(Self {
+            schema,
+            partitions,
+            infinite: false,
+        })
+    }
+    /// Sets whether the streaming table is infinite (unbounded).
+    pub fn with_infinite_table(mut self, infinite: bool) -> Self {
+        self.infinite = infinite;
+        self
     }
 }
 
@@ -88,6 +98,7 @@ impl TableProvider for StreamingTable {
             self.schema.clone(),
             self.partitions.clone(),
             projection,
+            self.infinite,
         )?))
     }
 }
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs 
b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 46f2d92903..c624c5c17d 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -44,7 +44,7 @@ use arrow::{
 };
 use futures::{ready, Stream, StreamExt, TryStreamExt};
 use hashbrown::raw::RawTable;
-use smallvec::{smallvec, SmallVec};
+use smallvec::smallvec;
 use std::fmt;
 use std::sync::Arc;
 use std::task::Poll;
@@ -87,26 +87,7 @@ use super::{
     utils::{OnceAsync, OnceFut},
     PartitionMode,
 };
-
-// Maps a `u64` hash value based on the build side ["on" values] to a list of 
indices with this key's value.
-//
-// Note that the `u64` keys are not stored in the hashmap (hence the `()` as 
key), but are only used
-// to put the indices in a certain bucket.
-// By allocating a `HashMap` with capacity for *at least* the number of rows 
for entries at the build side,
-// we make sure that we don't have to re-hash the hashmap, which needs access 
to the key (the hash in this case) value.
-// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 
for hash value 1
-// As the key is a hash value, we need to check possible hash collisions in 
the probe stage
-// During this stage it might be the case that a row is contained the same 
hashmap value,
-// but the values don't match. Those are checked in the [equal_rows] macro
-// TODO: speed up collision check and move away from using a hashbrown HashMap
-// https://github.com/apache/arrow-datafusion/issues/50
-pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>);
-
-impl fmt::Debug for JoinHashMap {
-    fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
-        Ok(())
-    }
-}
+use crate::physical_plan::joins::hash_join_utils::JoinHashMap;
 
 type JoinLeftData = (JoinHashMap, RecordBatch);
 
diff --git a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs 
b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
index ffba78c1d3..f411370ef0 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join_utils.rs
@@ -20,7 +20,7 @@
 
 use std::collections::HashMap;
 use std::sync::Arc;
-use std::usize;
+use std::{fmt, usize};
 
 use arrow::datatypes::SchemaRef;
 
@@ -29,10 +29,57 @@ use datafusion_physical_expr::expressions::Column;
 use datafusion_physical_expr::intervals::Interval;
 use datafusion_physical_expr::utils::collect_columns;
 use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
+use hashbrown::raw::RawTable;
+use smallvec::SmallVec;
 
 use crate::common::Result;
 use crate::physical_plan::joins::utils::{JoinFilter, JoinSide};
 
+// Maps a `u64` hash value based on the build side ["on" values] to a list of 
indices with this key's value.
+//
+// Note that the `u64` keys are not stored in the hashmap (hence the `()` as 
key), but are only used
+// to put the indices in a certain bucket.
+// By allocating a `HashMap` with capacity for *at least* the number of rows 
for entries at the build side,
+// we make sure that we don't have to re-hash the hashmap, which needs access 
to the key (the hash in this case) value.
+// E.g. 1 -> [3, 6, 8] indicates that the column values map to rows 3, 6 and 8 
for hash value 1
+// As the key is a hash value, we need to check possible hash collisions in 
the probe stage
+// During this stage it might be the case that a row is contained the same 
hashmap value,
+// but the values don't match. Those are checked in the [equal_rows] macro
+// TODO: speed up collision check and move away from using a hashbrown HashMap
+// https://github.com/apache/arrow-datafusion/issues/50
+pub struct JoinHashMap(pub RawTable<(u64, SmallVec<[u64; 1]>)>);
+
+impl JoinHashMap {
+    /// In this implementation, the scale_factor variable determines how 
conservative the shrinking strategy is.
+    /// The value of scale_factor is set to 4, which means the capacity will 
be reduced by 25%
+    /// when necessary. You can adjust the scale_factor value to achieve the 
desired
+    /// balance between memory usage and performance.
+    //
+    // If you increase the scale_factor, the capacity will shrink less 
aggressively,
+    // leading to potentially higher memory usage but fewer resizes.
+    // Conversely, if you decrease the scale_factor, the capacity will shrink 
more aggressively,
+    // potentially leading to lower memory usage but more frequent resizing.
+    pub(crate) fn shrink_if_necessary(&mut self, scale_factor: usize) {
+        let capacity = self.0.capacity();
+        let len = self.0.len();
+
+        if capacity > scale_factor * len {
+            let new_capacity = (capacity * (scale_factor - 1)) / scale_factor;
+            self.0.shrink_to(new_capacity, |(hash, _)| *hash)
+        }
+    }
+
+    pub(crate) fn size(&self) -> usize {
+        self.0.allocation_info().1.size()
+    }
+}
+
+impl fmt::Debug for JoinHashMap {
+    fn fmt(&self, _f: &mut fmt::Formatter) -> fmt::Result {
+        Ok(())
+    }
+}
+
 fn check_filter_expr_contains_sort_information(
     expr: &Arc<dyn PhysicalExpr>,
     reference: &Arc<dyn PhysicalExpr>,
@@ -243,6 +290,7 @@ pub mod tests {
     use datafusion_common::ScalarValue;
     use datafusion_expr::Operator;
     use datafusion_physical_expr::expressions::{binary, cast, col, lit};
+    use smallvec::smallvec;
     use std::sync::Arc;
 
     /// Filter expr for a + b > c + 10 AND a + b < c + 100
@@ -576,4 +624,44 @@ pub mod tests {
         assert!(res.is_none());
         Ok(())
     }
+
+    #[test]
+    fn test_shrink_if_necessary() {
+        let scale_factor = 4;
+        let mut join_hash_map = JoinHashMap(RawTable::with_capacity(100));
+        let data_size = 2000;
+        let deleted_part = 3 * data_size / 4;
+        // Add elements to the JoinHashMap
+        for hash_value in 0..data_size {
+            join_hash_map.0.insert(
+                hash_value,
+                (hash_value, smallvec![hash_value]),
+                |(hash, _)| *hash,
+            );
+        }
+
+        assert_eq!(join_hash_map.0.len(), data_size as usize);
+        assert!(join_hash_map.0.capacity() >= data_size as usize);
+
+        // Remove some elements from the JoinHashMap
+        for hash_value in 0..deleted_part {
+            join_hash_map
+                .0
+                .remove_entry(hash_value, |(hash, _)| hash_value == *hash);
+        }
+
+        assert_eq!(join_hash_map.0.len(), (data_size - deleted_part) as usize);
+
+        // Old capacity
+        let old_capacity = join_hash_map.0.capacity();
+
+        // Test shrink_if_necessary
+        join_hash_map.shrink_if_necessary(scale_factor);
+
+        // The capacity should be reduced by the scale factor
+        let new_expected_capacity =
+            join_hash_map.0.capacity() * (scale_factor - 1) / scale_factor;
+        assert!(join_hash_map.0.capacity() >= new_expected_capacity);
+        assert!(join_hash_map.0.capacity() <= old_capacity);
+    }
 }
diff --git a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs 
b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
index f5d06a8c0a..47c511a1a6 100644
--- a/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/symmetric_hash_join.rs
@@ -47,17 +47,20 @@ use hashbrown::{raw::RawTable, HashSet};
 use parking_lot::Mutex;
 
 use datafusion_common::{utils::bisect, ScalarValue};
+use datafusion_execution::memory_pool::MemoryConsumer;
 use datafusion_physical_expr::intervals::{ExprIntervalGraph, Interval};
 
 use crate::error::{DataFusionError, Result};
 use crate::execution::context::TaskContext;
 use crate::logical_expr::JoinType;
+use crate::physical_plan::common::SharedMemoryReservation;
 use 
crate::physical_plan::joins::hash_join_utils::convert_sort_expr_with_filter_schema;
+use crate::physical_plan::joins::hash_join_utils::JoinHashMap;
 use crate::physical_plan::{
     expressions::Column,
     expressions::PhysicalSortExpr,
     joins::{
-        hash_join::{build_join_indices, update_hash, JoinHashMap},
+        hash_join::{build_join_indices, update_hash},
         hash_join_utils::{build_filter_input_order, SortedFilterExpr},
         utils::{
             build_batch_from_indices, build_join_schema, check_join_is_valid,
@@ -70,6 +73,8 @@ use crate::physical_plan::{
     RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
 
+const HASHMAP_SHRINK_SCALE_FACTOR: usize = 4;
+
 /// A symmetric hash join with range conditions is when both streams are 
hashed on the
 /// join key and the resulting hash tables are used to join the streams.
 /// The join is considered symmetric because the hash table is built on the 
join keys from both
@@ -209,6 +214,8 @@ struct SymmetricHashJoinMetrics {
     left: SymmetricHashJoinSideMetrics,
     /// Number of right batches/rows consumed by this operator
     right: SymmetricHashJoinSideMetrics,
+    /// Memory used by sides in bytes
+    pub(crate) stream_memory_usage: metrics::Gauge,
     /// Number of batches produced by this operator
     output_batches: metrics::Count,
     /// Number of rows produced by this operator
@@ -233,6 +240,9 @@ impl SymmetricHashJoinMetrics {
             input_rows,
         };
 
+        let stream_memory_usage =
+            MetricBuilder::new(metrics).gauge("stream_memory_usage", 
partition);
+
         let output_batches =
             MetricBuilder::new(metrics).counter("output_batches", partition);
 
@@ -242,6 +252,7 @@ impl SymmetricHashJoinMetrics {
             left,
             right,
             output_batches,
+            stream_memory_usage,
             output_rows,
         }
     }
@@ -581,7 +592,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
 
         let right_stream = self
             .right
-            .execute(partition, context)?
+            .execute(partition, context.clone())?
             .map(|val| (JoinSide::Right, val));
         // This function will attempt to pull items from both streams.
         // Each stream will be polled in a round-robin fashion, and whenever a 
stream is
@@ -590,6 +601,14 @@ impl ExecutionPlan for SymmetricHashJoinExec {
         // The returned stream completes when both input streams have 
completed.
         let input_stream = select(left_stream, right_stream).boxed();
 
+        let reservation = Arc::new(Mutex::new(
+            
MemoryConsumer::new(format!("SymmetricHashJoinStream[{partition}]"))
+                .register(context.memory_pool()),
+        ));
+        if let Some(g) = graph.as_ref() {
+            reservation.lock().try_grow(g.size())?;
+        }
+
         Ok(Box::pin(SymmetricHashJoinStream {
             input_stream,
             schema: self.schema(),
@@ -605,6 +624,7 @@ impl ExecutionPlan for SymmetricHashJoinExec {
             right_sorted_filter_expr,
             null_equals_null: self.null_equals_null,
             final_result: false,
+            reservation,
         }))
     }
 }
@@ -637,6 +657,8 @@ struct SymmetricHashJoinStream {
     null_equals_null: bool,
     /// Metrics
     metrics: SymmetricHashJoinMetrics,
+    /// Memory reservation
+    reservation: SharedMemoryReservation,
     /// Flag indicating whether there is nothing to process anymore
     final_result: bool,
 }
@@ -689,6 +711,7 @@ fn prune_hash_values(
             }
         }
     }
+    hashmap.shrink_if_necessary(HASHMAP_SHRINK_SCALE_FACTOR);
     Ok(())
 }
 
@@ -1041,12 +1064,26 @@ struct OneSideHashJoiner {
 }
 
 impl OneSideHashJoiner {
+    pub fn size(&self) -> usize {
+        let mut size = 0;
+        size += std::mem::size_of_val(self);
+        size += std::mem::size_of_val(&self.build_side);
+        size += self.input_buffer.get_array_memory_size();
+        size += std::mem::size_of_val(&self.on);
+        size += self.hashmap.size();
+        size += self.row_hash_values.capacity() * std::mem::size_of::<u64>();
+        size += self.hashes_buffer.capacity() * std::mem::size_of::<u64>();
+        size += self.visited_rows.capacity() * std::mem::size_of::<usize>();
+        size += std::mem::size_of_val(&self.offset);
+        size += std::mem::size_of_val(&self.deleted_offset);
+        size
+    }
     pub fn new(build_side: JoinSide, on: Vec<Column>, schema: SchemaRef) -> 
Self {
         Self {
             build_side,
             input_buffer: RecordBatch::new_empty(schema),
             on,
-            hashmap: JoinHashMap(RawTable::with_capacity(10_000)),
+            hashmap: JoinHashMap(RawTable::with_capacity(0)),
             row_hash_values: VecDeque::new(),
             hashes_buffer: vec![],
             visited_rows: HashSet::new(),
@@ -1074,6 +1111,7 @@ impl OneSideHashJoiner {
         self.input_buffer = concat_batches(&batch.schema(), 
[&self.input_buffer, batch])?;
         // Resize the hashes buffer to the number of rows in the incoming 
batch:
         self.hashes_buffer.resize(batch.num_rows(), 0);
+        // Get allocation_info before adding the item
         // Update the hashmap with the join key values and hashes of the 
incoming batch:
         update_hash(
             &self.on,
@@ -1339,6 +1377,24 @@ fn combine_two_batches(
 }
 
 impl SymmetricHashJoinStream {
+    fn size(&self) -> usize {
+        let mut size = 0;
+        size += std::mem::size_of_val(&self.input_stream);
+        size += std::mem::size_of_val(&self.schema);
+        size += std::mem::size_of_val(&self.filter);
+        size += std::mem::size_of_val(&self.join_type);
+        size += self.left.size();
+        size += self.right.size();
+        size += std::mem::size_of_val(&self.column_indices);
+        size += self.graph.as_ref().map(|g| g.size()).unwrap_or(0);
+        size += std::mem::size_of_val(&self.left_sorted_filter_expr);
+        size += std::mem::size_of_val(&self.right_sorted_filter_expr);
+        size += std::mem::size_of_val(&self.random_state);
+        size += std::mem::size_of_val(&self.null_equals_null);
+        size += std::mem::size_of_val(&self.metrics);
+        size += std::mem::size_of_val(&self.final_result);
+        size
+    }
     /// Polls the next result of the join operation.
     ///
     /// If the result of the join is ready, it returns the next record batch.
@@ -1442,6 +1498,9 @@ impl SymmetricHashJoinStream {
                     // Combine results:
                     let result =
                         combine_two_batches(&self.schema, equal_result, 
anti_result)?;
+                    let capacity = self.size();
+                    self.metrics.stream_memory_usage.set(capacity);
+                    self.reservation.lock().try_resize(capacity)?;
                     // Update the metrics if we have a batch; otherwise, 
continue the loop.
                     if let Some(batch) = &result {
                         self.metrics.output_batches.add(1);
@@ -1482,7 +1541,6 @@ impl SymmetricHashJoinStream {
                         // Update the metrics:
                         self.metrics.output_batches.add(1);
                         self.metrics.output_rows.add(batch.num_rows());
-
                         return Poll::Ready(Ok(result).transpose());
                     }
                 }
diff --git a/datafusion/core/src/physical_plan/streaming.rs 
b/datafusion/core/src/physical_plan/streaming.rs
index efd43aca6b..ff5d88dd31 100644
--- a/datafusion/core/src/physical_plan/streaming.rs
+++ b/datafusion/core/src/physical_plan/streaming.rs
@@ -37,6 +37,7 @@ pub struct StreamingTableExec {
     partitions: Vec<Arc<dyn PartitionStream>>,
     projection: Option<Arc<[usize]>>,
     projected_schema: SchemaRef,
+    infinite: bool,
 }
 
 impl StreamingTableExec {
@@ -45,6 +46,7 @@ impl StreamingTableExec {
         schema: SchemaRef,
         partitions: Vec<Arc<dyn PartitionStream>>,
         projection: Option<&Vec<usize>>,
+        infinite: bool,
     ) -> Result<Self> {
         if !partitions.iter().all(|x| schema.contains(x.schema())) {
             return Err(DataFusionError::Plan(
@@ -61,6 +63,7 @@ impl StreamingTableExec {
             partitions,
             projected_schema,
             projection: projection.cloned().map(Into::into),
+            infinite,
         })
     }
 }
@@ -85,6 +88,10 @@ impl ExecutionPlan for StreamingTableExec {
         Partitioning::UnknownPartitioning(self.partitions.len())
     }
 
+    fn unbounded_output(&self, _children: &[bool]) -> Result<bool> {
+        Ok(self.infinite)
+    }
+
     fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
         None
     }
diff --git a/datafusion/core/tests/memory_limit.rs 
b/datafusion/core/tests/memory_limit.rs
index 13b6291f9c..71f7f30bf0 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -17,15 +17,23 @@
 
 //! This module contains tests for limiting memory at runtime in DataFusion
 
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use futures::StreamExt;
 use std::sync::Arc;
 
+use datafusion::datasource::streaming::{PartitionStream, StreamingTable};
 use datafusion::datasource::MemTable;
 use datafusion::execution::context::SessionState;
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::physical_optimizer::pipeline_fixer::PipelineFixer;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion_common::assert_contains;
 
 use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_execution::TaskContext;
 use test_utils::AccessLogGenerator;
 
 #[cfg(test)]
@@ -161,6 +169,23 @@ async fn merge_join() {
     .await
 }
 
+#[tokio::test]
+async fn test_limit_symmetric_hash_join() {
+    let config = SessionConfig::new();
+
+    run_streaming_test_with_config(
+        "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = 
t2.time",
+        vec![
+            "Resources exhausted: Failed to allocate additional",
+            "SymmetricHashJoinStream",
+        ],
+        1_000,
+        config,
+    )
+    .await
+}
+
+/// Memory fraction passed to `RuntimeConfig::with_memory_limit`
 const MEMORY_FRACTION: f64 = 0.95;
 
 /// runs the specified query against 1000 rows with specified
@@ -219,3 +244,86 @@ async fn run_limit_test_with_config(
         }
     }
 }
+
+struct DummyStreamPartition {
+    schema: SchemaRef,
+    batches: Vec<RecordBatch>,
+}
+
+impl PartitionStream for DummyStreamPartition {
+    fn schema(&self) -> &SchemaRef {
+        &self.schema
+    }
+
+    fn execute(&self, _ctx: Arc<TaskContext>) -> SendableRecordBatchStream {
+        // We create an iterator from the record batches and map them into Ok 
values,
+        // converting the iterator into a futures::stream::Stream
+        Box::pin(RecordBatchStreamAdapter::new(
+            self.schema.clone(),
+            futures::stream::iter(self.batches.clone()).map(Ok),
+        ))
+    }
+}
+
+async fn run_streaming_test_with_config(
+    query: &str,
+    expected_error_contains: Vec<&str>,
+    memory_limit: usize,
+    config: SessionConfig,
+) {
+    // Generate a set of access logs with a row limit of 1000 and a max batch 
size of 50
+    let batches: Vec<_> = AccessLogGenerator::new()
+        .with_row_limit(1000)
+        .with_max_batch_size(50)
+        .collect();
+
+    // Create a new streaming table with the generated schema and batches
+    let table = StreamingTable::try_new(
+        batches[0].schema(),
+        vec![Arc::new(DummyStreamPartition {
+            schema: batches[0].schema(),
+            batches: batches.clone(),
+        })],
+    )
+    .unwrap()
+    .with_infinite_table(true);
+
+    // Configure the runtime environment with custom settings
+    let rt_config = RuntimeConfig::new()
+        // Disable disk manager to disallow spilling
+        .with_disk_manager(DiskManagerConfig::Disabled)
+        // Apply the caller-supplied memory limit together with MEMORY_FRACTION
+        .with_memory_limit(memory_limit, MEMORY_FRACTION);
+
+    // Create a new runtime environment with the configured settings
+    let runtime = RuntimeEnv::new(rt_config).unwrap();
+
+    // Create a new session state with the given configuration and runtime 
environment
+    // Disable all physical optimizer rules except the PipelineFixer rule to 
avoid sorts or
+    // repartition, as they also have memory budgets that may be hit first
+    let state = SessionState::with_config_rt(config, Arc::new(runtime))
+        .with_physical_optimizer_rules(vec![Arc::new(PipelineFixer::new())]);
+
+    // Create a new session context with the session state
+    let ctx = SessionContext::with_state(state);
+    // Register the streaming table with the session context
+    ctx.register_table("t", Arc::new(table))
+        .expect("registering table");
+
+    // Execute the SQL query and get a DataFrame
+    let df = ctx.sql(query).await.expect("Planning query");
+
+    // Collect the results of the DataFrame execution
+    match df.collect().await {
+        // If the execution succeeds, panic as we expect memory limit failure
+        Ok(_batches) => {
+            panic!("Unexpected success when running, expected memory limit 
failure")
+        }
+        // If the execution fails, verify if the error contains the expected 
substrings
+        Err(e) => {
+            for error_substring in expected_error_contains {
+                assert_contains!(e.to_string(), error_substring);
+            }
+        }
+    }
+}
diff --git a/datafusion/execution/src/memory_pool/mod.rs 
b/datafusion/execution/src/memory_pool/mod.rs
index f1f745d4eb..066cec9fb1 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -143,6 +143,17 @@ impl MemoryReservation {
         }
     }
 
+    /// Try to set the size of this reservation to `capacity`
+    pub fn try_resize(&mut self, capacity: usize) -> Result<()> {
+        use std::cmp::Ordering;
+        match capacity.cmp(&self.size) {
+            Ordering::Greater => self.try_grow(capacity - self.size)?,
+            Ordering::Less => self.shrink(self.size - capacity),
+            _ => {}
+        };
+        Ok(())
+    }
+
     /// Increase the size of this reservation by `capacity` bytes
     pub fn grow(&mut self, capacity: usize) {
         self.policy.grow(self, capacity);
diff --git a/datafusion/physical-expr/src/intervals/cp_solver.rs 
b/datafusion/physical-expr/src/intervals/cp_solver.rs
index a54bcb2f35..5e9353599e 100644
--- a/datafusion/physical-expr/src/intervals/cp_solver.rs
+++ b/datafusion/physical-expr/src/intervals/cp_solver.rs
@@ -122,6 +122,19 @@ pub struct ExprIntervalGraph {
     root: NodeIndex,
 }
 
+impl ExprIntervalGraph {
+    /// Estimate size in bytes, including `Self`.
+    pub fn size(&self) -> usize {
+        let node_memory_usage = self.graph.node_count()
+            * (std::mem::size_of::<ExprIntervalGraphNode>()
+                + std::mem::size_of::<NodeIndex>());
+        let edge_memory_usage = self.graph.edge_count()
+            * (std::mem::size_of::<usize>() + std::mem::size_of::<NodeIndex>() 
* 2);
+
+        std::mem::size_of_val(self) + node_memory_usage + edge_memory_usage
+    }
+}
+
 /// This object encapsulates all possible constraint propagation results.
 #[derive(PartialEq, Debug)]
 pub enum PropagationResult {

Reply via email to