This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 612eb1d0c Memory limited nested-loop join (#5564)
612eb1d0c is described below

commit 612eb1d0ce338af7980fa906df8796eb47c4be44
Author: Eduard Karacharov <[email protected]>
AuthorDate: Tue Mar 14 16:32:58 2023 +0300

    Memory limited nested-loop join (#5564)
    
    * memory limited nl join
    
    * shared reservations as structs
---
 datafusion/core/src/physical_plan/common.rs        |   5 -
 .../core/src/physical_plan/joins/cross_join.rs     |  28 +--
 .../core/src/physical_plan/joins/hash_join.rs      |  43 ++---
 .../src/physical_plan/joins/nested_loop_join.rs    | 195 +++++++++++++++++++--
 datafusion/core/tests/memory_limit.rs              |  10 ++
 datafusion/execution/src/memory_pool/mod.rs        | 120 ++++++++++++-
 6 files changed, 329 insertions(+), 72 deletions(-)

diff --git a/datafusion/core/src/physical_plan/common.rs 
b/datafusion/core/src/physical_plan/common.rs
index 2f02aaa27..7fb67d758 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -42,11 +42,6 @@ use tokio::task::JoinHandle;
 /// [`MemoryReservation`] used across query execution streams
 pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
 
-/// [`MemoryReservation`] used at query operator level
-/// `Option` wrapper allows to initialize empty reservation in operator 
constructor,
-/// and set it to actual reservation at stream level.
-pub(crate) type OperatorMemoryReservation = 
Arc<Mutex<Option<SharedMemoryReservation>>>;
-
 /// Stream of record batches
 pub struct SizedRecordBatchStream {
     schema: SchemaRef,
diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs 
b/datafusion/core/src/physical_plan/joins/cross_join.rs
index d4933b9d6..8492e5e6b 100644
--- a/datafusion/core/src/physical_plan/joins/cross_join.rs
+++ b/datafusion/core/src/physical_plan/joins/cross_join.rs
@@ -26,8 +26,7 @@ use arrow::datatypes::{Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 
 use crate::execution::context::TaskContext;
-use crate::execution::memory_pool::MemoryConsumer;
-use crate::physical_plan::common::{OperatorMemoryReservation, 
SharedMemoryReservation};
+use crate::execution::memory_pool::{SharedOptionalMemoryReservation, TryGrow};
 use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
     coalesce_batches::concat_batches, 
coalesce_partitions::CoalescePartitionsExec,
@@ -38,7 +37,6 @@ use crate::physical_plan::{
 use crate::{error::Result, scalar::ScalarValue};
 use async_trait::async_trait;
 use datafusion_common::DataFusionError;
-use parking_lot::Mutex;
 
 use super::utils::{
     adjust_right_output_partitioning, cross_join_equivalence_properties,
@@ -61,7 +59,7 @@ pub struct CrossJoinExec {
     /// Build-side data
     left_fut: OnceAsync<JoinLeftData>,
     /// Memory reservation for build-side data
-    reservation: OperatorMemoryReservation,
+    reservation: SharedOptionalMemoryReservation,
     /// Execution plan metrics
     metrics: ExecutionPlanMetricsSet,
 }
@@ -106,7 +104,7 @@ async fn load_left_input(
     left: Arc<dyn ExecutionPlan>,
     context: Arc<TaskContext>,
     metrics: BuildProbeJoinMetrics,
-    reservation: SharedMemoryReservation,
+    reservation: SharedOptionalMemoryReservation,
 ) -> Result<JoinLeftData> {
     // merge all left parts into a single stream
     let merge = {
@@ -125,7 +123,7 @@ async fn load_left_input(
             |mut acc, batch| async {
                 let batch_size = batch.get_array_memory_size();
                 // Reserve memory for incoming batch
-                acc.3.lock().try_grow(batch_size)?;
+                acc.3.try_grow(batch_size)?;
                 // Update metrics
                 acc.2.build_mem_used.add(batch_size);
                 acc.2.build_input_batches.add(1);
@@ -226,27 +224,15 @@ impl ExecutionPlan for CrossJoinExec {
         let join_metrics = BuildProbeJoinMetrics::new(partition, 
&self.metrics);
 
         // Initialization of operator-level reservation
-        {
-            let mut reservation_lock = self.reservation.lock();
-            if reservation_lock.is_none() {
-                *reservation_lock = Some(Arc::new(Mutex::new(
-                    
MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()),
-                )));
-            };
-        }
-
-        let reservation = self.reservation.lock().clone().ok_or_else(|| {
-            DataFusionError::Internal(
-                "Operator-level memory reservation is not 
initialized".to_string(),
-            )
-        })?;
+        self.reservation
+            .initialize("CrossJoinExec", context.memory_pool());
 
         let left_fut = self.left_fut.once(|| {
             load_left_input(
                 self.left.clone(),
                 context,
                 join_metrics.clone(),
-                reservation,
+                self.reservation.clone(),
             )
         });
 
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs 
b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 0d2b897dd..39acffa20 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -58,7 +58,6 @@ use hashbrown::raw::RawTable;
 use crate::physical_plan::{
     coalesce_batches::concat_batches,
     coalesce_partitions::CoalescePartitionsExec,
-    common::{OperatorMemoryReservation, SharedMemoryReservation},
     expressions::Column,
     expressions::PhysicalSortExpr,
     hash_utils::create_hashes,
@@ -78,7 +77,12 @@ use crate::logical_expr::JoinType;
 
 use crate::arrow::array::BooleanBufferBuilder;
 use crate::arrow::datatypes::TimeUnit;
-use crate::execution::{context::TaskContext, memory_pool::MemoryConsumer};
+use crate::execution::{
+    context::TaskContext,
+    memory_pool::{
+        MemoryConsumer, SharedMemoryReservation, 
SharedOptionalMemoryReservation, TryGrow,
+    },
+};
 
 use super::{
     utils::{OnceAsync, OnceFut},
@@ -88,7 +92,6 @@ use crate::physical_plan::joins::utils::{
     adjust_indices_by_join_type, apply_join_filter_to_indices, 
build_batch_from_indices,
     get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide,
 };
-use parking_lot::Mutex;
 use std::fmt;
 use std::task::Poll;
 
@@ -137,7 +140,7 @@ pub struct HashJoinExec {
     /// Build-side data
     left_fut: OnceAsync<JoinLeftData>,
     /// Operator-level memory reservation for left data
-    reservation: OperatorMemoryReservation,
+    reservation: SharedOptionalMemoryReservation,
     /// Shares the `RandomState` for the hashing algorithm
     random_state: RandomState,
     /// Partitioning mode to use
@@ -378,26 +381,14 @@ impl ExecutionPlan for HashJoinExec {
         let join_metrics = BuildProbeJoinMetrics::new(partition, 
&self.metrics);
 
         // Initialization of operator-level reservation
-        {
-            let mut operator_reservation_lock = self.reservation.lock();
-            if operator_reservation_lock.is_none() {
-                *operator_reservation_lock = Some(Arc::new(Mutex::new(
-                    
MemoryConsumer::new("HashJoinExec").register(context.memory_pool()),
-                )));
-            };
-        }
-
-        let operator_reservation = 
self.reservation.lock().clone().ok_or_else(|| {
-            DataFusionError::Internal(
-                "Operator-level memory reservation is not 
initialized".to_string(),
-            )
-        })?;
+        self.reservation
+            .initialize("HashJoinExec", context.memory_pool());
 
         // Initialization of stream-level reservation
-        let reservation = Arc::new(Mutex::new(
+        let reservation = SharedMemoryReservation::from(
             MemoryConsumer::new(format!("HashJoinStream[{partition}]"))
                 .register(context.memory_pool()),
-        ));
+        );
 
         // Memory reservation for left-side data depends on PartitionMode:
         // - operator-level for `CollectLeft` mode
@@ -415,7 +406,7 @@ impl ExecutionPlan for HashJoinExec {
                     on_left.clone(),
                     context.clone(),
                     join_metrics.clone(),
-                    operator_reservation.clone(),
+                    Arc::new(self.reservation.clone()),
                 )
             }),
             PartitionMode::Partitioned => OnceFut::new(collect_left_input(
@@ -425,7 +416,7 @@ impl ExecutionPlan for HashJoinExec {
                 on_left.clone(),
                 context.clone(),
                 join_metrics.clone(),
-                reservation.clone(),
+                Arc::new(reservation.clone()),
             )),
             PartitionMode::Auto => {
                 return Err(DataFusionError::Plan(format!(
@@ -497,7 +488,7 @@ async fn collect_left_input(
     on_left: Vec<Column>,
     context: Arc<TaskContext>,
     metrics: BuildProbeJoinMetrics,
-    reservation: SharedMemoryReservation,
+    reservation: Arc<dyn TryGrow>,
 ) -> Result<JoinLeftData> {
     let schema = left.schema();
 
@@ -526,7 +517,7 @@ async fn collect_left_input(
         .try_fold(initial, |mut acc, batch| async {
             let batch_size = batch.get_array_memory_size();
             // Reserve memory for incoming batch
-            acc.3.lock().try_grow(batch_size)?;
+            acc.3.try_grow(batch_size)?;
             // Update metrics
             acc.2.build_mem_used.add(batch_size);
             acc.2.build_input_batches.add(1);
@@ -555,7 +546,7 @@ async fn collect_left_input(
     // + 16 bytes fixed
     let estimated_hastable_size = 32 * estimated_buckets + estimated_buckets + 
16;
 
-    reservation.lock().try_grow(estimated_hastable_size)?;
+    reservation.try_grow(estimated_hastable_size)?;
     metrics.build_mem_used.add(estimated_hastable_size);
 
     let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
@@ -1157,7 +1148,7 @@ impl HashJoinStream {
             // TODO: Replace `ceil` wrapper with stable `div_cell` after
             // https://github.com/rust-lang/rust/issues/88581
             let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(), 
8);
-            self.reservation.lock().try_grow(visited_bitmap_size)?;
+            self.reservation.try_grow(visited_bitmap_size)?;
             self.join_metrics.build_mem_used.add(visited_bitmap_size);
         }
 
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs 
b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index c283b11f8..e04e86d0d 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -24,8 +24,10 @@ use crate::physical_plan::joins::utils::{
     build_batch_from_indices, build_join_schema, check_join_is_valid,
     combine_join_equivalence_properties, estimate_join_statistics, 
get_anti_indices,
     get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices,
-    get_semi_u64_indices, ColumnIndex, JoinFilter, JoinSide, OnceAsync, 
OnceFut,
+    get_semi_u64_indices, BuildProbeJoinMetrics, ColumnIndex, JoinFilter, 
JoinSide,
+    OnceAsync, OnceFut,
 };
+use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
 use crate::physical_plan::{
     DisplayFormatType, Distribution, ExecutionPlan, Partitioning, 
RecordBatchStream,
     SendableRecordBatchStream,
@@ -35,19 +37,21 @@ use arrow::array::{
 };
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
+use arrow::util::bit_util;
 use datafusion_common::{DataFusionError, Statistics};
 use datafusion_expr::JoinType;
 use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr};
 use futures::{ready, Stream, StreamExt, TryStreamExt};
-use log::debug;
 use std::any::Any;
 use std::fmt::Formatter;
 use std::sync::Arc;
 use std::task::Poll;
-use std::time::Instant;
 
 use crate::error::Result;
 use crate::execution::context::TaskContext;
+use crate::execution::memory_pool::{
+    MemoryConsumer, SharedMemoryReservation, SharedOptionalMemoryReservation, 
TryGrow,
+};
 use crate::physical_plan::coalesce_batches::concat_batches;
 
 /// Data of the inner table side
@@ -87,6 +91,10 @@ pub struct NestedLoopJoinExec {
     inner_table: OnceAsync<JoinLeftData>,
     /// Information of index and left / right placement of columns
     column_indices: Vec<ColumnIndex>,
+    /// Operator-level memory reservation for left data
+    reservation: SharedOptionalMemoryReservation,
+    /// Execution metrics
+    metrics: ExecutionPlanMetricsSet,
 }
 
 impl NestedLoopJoinExec {
@@ -110,6 +118,8 @@ impl NestedLoopJoinExec {
             schema: Arc::new(schema),
             inner_table: Default::default(),
             column_indices,
+            reservation: Default::default(),
+            metrics: Default::default(),
         })
     }
 }
@@ -189,17 +199,41 @@ impl ExecutionPlan for NestedLoopJoinExec {
         partition: usize,
         context: Arc<TaskContext>,
     ) -> Result<SendableRecordBatchStream> {
+        let join_metrics = BuildProbeJoinMetrics::new(partition, 
&self.metrics);
+
+        // Initialization of operator-level reservation
+        self.reservation
+            .initialize("NestedLoopJoinExec", context.memory_pool());
+
+        // Initialization of stream-level reservation
+        let reservation = SharedMemoryReservation::from(
+            MemoryConsumer::new(format!("NestedLoopJoinStream[{partition}]"))
+                .register(context.memory_pool()),
+        );
+
         let (outer_table, inner_table) = if left_is_build_side(self.join_type) 
{
             // left must be single partition
             let inner_table = self.inner_table.once(|| {
-                load_specified_partition_of_input(0, self.left.clone(), 
context.clone())
+                load_specified_partition_of_input(
+                    0,
+                    self.left.clone(),
+                    context.clone(),
+                    join_metrics.clone(),
+                    Arc::new(self.reservation.clone()),
+                )
             });
             let outer_table = self.right.execute(partition, context)?;
             (outer_table, inner_table)
         } else {
             // right must be single partition
             let inner_table = self.inner_table.once(|| {
-                load_specified_partition_of_input(0, self.right.clone(), 
context.clone())
+                load_specified_partition_of_input(
+                    0,
+                    self.right.clone(),
+                    context.clone(),
+                    join_metrics.clone(),
+                    Arc::new(self.reservation.clone()),
+                )
             });
             let outer_table = self.left.execute(partition, context)?;
             (outer_table, inner_table)
@@ -214,6 +248,8 @@ impl ExecutionPlan for NestedLoopJoinExec {
             is_exhausted: false,
             visited_left_side: None,
             column_indices: self.column_indices.clone(),
+            join_metrics,
+            reservation,
         }))
     }
 
@@ -233,6 +269,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
         }
     }
 
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.metrics.clone_inner())
+    }
+
     fn statistics(&self) -> Statistics {
         estimate_join_statistics(
             self.left.clone(),
@@ -273,28 +313,34 @@ async fn load_specified_partition_of_input(
     partition: usize,
     input: Arc<dyn ExecutionPlan>,
     context: Arc<TaskContext>,
+    join_metrics: BuildProbeJoinMetrics,
+    reservation: Arc<dyn TryGrow>,
 ) -> Result<JoinLeftData> {
-    let start = Instant::now();
     let stream = input.execute(partition, context)?;
 
     // Load all batches and count the rows
-    let (batches, num_rows) = stream
-        .try_fold((Vec::new(), 0usize), |mut acc, batch| async {
-            acc.1 += batch.num_rows();
-            acc.0.push(batch);
-            Ok(acc)
-        })
+    let (batches, num_rows, _, _) = stream
+        .try_fold(
+            (Vec::new(), 0usize, join_metrics, reservation),
+            |mut acc, batch| async {
+                let batch_size = batch.get_array_memory_size();
+                // Reserve memory for incoming batch
+                acc.3.try_grow(batch_size)?;
+                // Update metrics
+                acc.2.build_mem_used.add(batch_size);
+                acc.2.build_input_batches.add(1);
+                acc.2.build_input_rows.add(batch.num_rows());
+                // Update rowcount
+                acc.1 += batch.num_rows();
+                // Push batch to output
+                acc.0.push(batch);
+                Ok(acc)
+            },
+        )
         .await?;
 
     let merged_batch = concat_batches(&input.schema(), &batches, num_rows)?;
 
-    debug!(
-        "Built input of nested loop join containing {} rows in {} ms for 
partition {}",
-        num_rows,
-        start.elapsed().as_millis(),
-        partition
-    );
-
     Ok(merged_batch)
 }
 
@@ -326,6 +372,10 @@ struct NestedLoopJoinStream {
     column_indices: Vec<ColumnIndex>,
     // TODO: support null aware equal
     // null_equals_null: bool
+    /// Join execution metrics
+    join_metrics: BuildProbeJoinMetrics,
+    /// Memory reservation for visited_left_side
+    reservation: SharedMemoryReservation,
 }
 
 fn build_join_indices(
@@ -362,10 +412,20 @@ impl NestedLoopJoinStream {
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         // all left row
+        let build_timer = self.join_metrics.build_time.timer();
         let left_data = match ready!(self.inner_table.get(cx)) {
             Ok(data) => data,
             Err(e) => return Poll::Ready(Some(Err(e))),
         };
+        build_timer.done();
+
+        if self.visited_left_side.is_none() && self.join_type == 
JoinType::Full {
+            // TODO: Replace `ceil` wrapper with stable `div_ceil` after
+            // https://github.com/rust-lang/rust/issues/88581
+            let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8);
+            self.reservation.try_grow(visited_bitmap_size)?;
+            self.join_metrics.build_mem_used.add(visited_bitmap_size);
+        }
 
         // add a bitmap for full join.
         let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
@@ -384,6 +444,11 @@ impl NestedLoopJoinStream {
             .poll_next_unpin(cx)
             .map(|maybe_batch| match maybe_batch {
                 Some(Ok(right_batch)) => {
+                    // Setting up timer & updating input metrics
+                    self.join_metrics.input_batches.add(1);
+                    self.join_metrics.input_rows.add(right_batch.num_rows());
+                    let timer = self.join_metrics.join_time.timer();
+
                     let result = join_left_and_right_batch(
                         left_data,
                         &right_batch,
@@ -393,11 +458,22 @@ impl NestedLoopJoinStream {
                         &self.schema,
                         visited_left_side,
                     );
+
+                    // Recording time & updating output metrics
+                    if let Ok(batch) = &result {
+                        timer.done();
+                        self.join_metrics.output_batches.add(1);
+                        self.join_metrics.output_rows.add(batch.num_rows());
+                    }
+
                     Some(result)
                 }
                 Some(err) => Some(err),
                 None => {
                     if self.join_type == JoinType::Full && !self.is_exhausted {
+                        // Only setting up timer, input is exhausted
+                        let timer = self.join_metrics.join_time.timer();
+
                         // use the global left bitmap to produce the left 
indices and right indices
                         let (left_side, right_side) = 
get_final_indices_from_bit_map(
                             visited_left_side,
@@ -416,6 +492,14 @@ impl NestedLoopJoinStream {
                             JoinSide::Left,
                         );
                         self.is_exhausted = true;
+
+                        // Recording time & updating output metrics
+                        if let Ok(batch) = &result {
+                            timer.done();
+                            self.join_metrics.output_batches.add(1);
+                            
self.join_metrics.output_rows.add(batch.num_rows());
+                        }
+
                         Some(result)
                     } else {
                         // end of the join loop
@@ -431,10 +515,12 @@ impl NestedLoopJoinStream {
         cx: &mut std::task::Context<'_>,
     ) -> Poll<Option<Result<RecordBatch>>> {
         // all right row
+        let build_timer = self.join_metrics.build_time.timer();
         let right_data = match ready!(self.inner_table.get(cx)) {
             Ok(data) => data,
             Err(e) => return Poll::Ready(Some(Err(e))),
         };
+        build_timer.done();
 
         // for build right, bitmap is not needed.
         let mut empty_visited_left_side = BooleanBufferBuilder::new(0);
@@ -442,6 +528,12 @@ impl NestedLoopJoinStream {
             .poll_next_unpin(cx)
             .map(|maybe_batch| match maybe_batch {
                 Some(Ok(left_batch)) => {
+                    // Setting up timer & updating input metrics
+                    self.join_metrics.input_batches.add(1);
+                    self.join_metrics.input_rows.add(left_batch.num_rows());
+                    let timer = self.join_metrics.join_time.timer();
+
+                    // Actual join execution
                     let result = join_left_and_right_batch(
                         &left_batch,
                         right_data,
@@ -451,6 +543,14 @@ impl NestedLoopJoinStream {
                         &self.schema,
                         &mut empty_visited_left_side,
                     );
+
+                    // Recording time & updating output metrics
+                    if let Ok(batch) = &result {
+                        timer.done();
+                        self.join_metrics.output_batches.add(1);
+                        self.join_metrics.output_rows.add(batch.num_rows());
+                    }
+
                     Some(result)
                 }
                 Some(err) => Some(err),
@@ -633,6 +733,11 @@ mod tests {
     use crate::physical_expr::expressions::BinaryExpr;
     use crate::{
         assert_batches_sorted_eq,
+        common::assert_contains,
+        execution::{
+            context::SessionConfig,
+            runtime_env::{RuntimeConfig, RuntimeEnv},
+        },
         physical_plan::{
             common, expressions::Column, memory::MemoryExec, 
repartition::RepartitionExec,
         },
@@ -1016,4 +1121,56 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_overallocation() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 11]),
+            ("b2", &vec![12, 13]),
+            ("c2", &vec![14, 15]),
+        );
+        let filter = prepare_join_filter();
+
+        let join_types = vec![
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftSemi,
+            JoinType::LeftAnti,
+            JoinType::RightSemi,
+            JoinType::RightAnti,
+        ];
+
+        for join_type in join_types {
+            let runtime_config = RuntimeConfig::new().with_memory_limit(100, 
1.0);
+            let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+            let session_ctx =
+                SessionContext::with_config_rt(SessionConfig::default(), 
runtime);
+            let task_ctx = session_ctx.task_ctx();
+
+            let err = multi_partitioned_join_collect(
+                left.clone(),
+                right.clone(),
+                &join_type,
+                Some(filter.clone()),
+                task_ctx,
+            )
+            .await
+            .unwrap_err();
+
+            assert_contains!(
+                err.to_string(),
+                "External error: Resources exhausted: Failed to allocate 
additional"
+            );
+            assert_contains!(err.to_string(), "NestedLoopJoinExec");
+        }
+
+        Ok(())
+    }
 }
diff --git a/datafusion/core/tests/memory_limit.rs 
b/datafusion/core/tests/memory_limit.rs
index 392e54941..9de8aff77 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -84,6 +84,16 @@ async fn join_by_key() {
     .await
 }
 
+#[tokio::test]
+async fn join_by_expression() {
+    run_limit_test(
+        "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service",
+        "Resources exhausted: Failed to allocate additional",
+        1_000,
+    )
+    .await
+}
+
 #[tokio::test]
 async fn cross_join() {
     run_limit_test(
diff --git a/datafusion/execution/src/memory_pool/mod.rs 
b/datafusion/execution/src/memory_pool/mod.rs
index f68a25650..f1f745d4e 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -17,7 +17,8 @@
 
 //! Manages all available memory during query execution
 
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
+use parking_lot::Mutex;
 use std::sync::Arc;
 
 mod pool;
@@ -163,6 +164,64 @@ impl Drop for MemoryReservation {
     }
 }
 
+pub trait TryGrow: Send + Sync + std::fmt::Debug {
+    fn try_grow(&self, capacity: usize) -> Result<()>;
+}
+
+/// Cloneable reference to [`MemoryReservation`] instance with interior 
mutability support
+#[derive(Clone, Debug)]
+pub struct SharedMemoryReservation(Arc<Mutex<MemoryReservation>>);
+
+impl From<MemoryReservation> for SharedMemoryReservation {
+    /// Creates new [`SharedMemoryReservation`] from [`MemoryReservation`]
+    fn from(reservation: MemoryReservation) -> Self {
+        Self(Arc::new(Mutex::new(reservation)))
+    }
+}
+
+impl TryGrow for SharedMemoryReservation {
+    /// Try to increase the size of this reservation by `capacity` bytes
+    fn try_grow(&self, capacity: usize) -> Result<()> {
+        self.0.lock().try_grow(capacity)
+    }
+}
+
+/// Cloneable reference to [`MemoryReservation`] instance with interior 
mutability support.
+/// Doesn't require [`MemoryReservation`] while creation, and can be 
initialized later.
+#[derive(Clone, Debug)]
+pub struct 
SharedOptionalMemoryReservation(Arc<Mutex<Option<MemoryReservation>>>);
+
+impl SharedOptionalMemoryReservation {
+    /// Initialize inner [`MemoryReservation`] if `None`, otherwise -- do 
nothing
+    pub fn initialize(&self, name: impl Into<String>, pool: &Arc<dyn 
MemoryPool>) {
+        let mut locked = self.0.lock();
+        if locked.is_none() {
+            *locked = Some(MemoryConsumer::new(name).register(pool));
+        };
+    }
+}
+
+impl TryGrow for SharedOptionalMemoryReservation {
+    /// Try to increase the size of this reservation by `capacity` bytes
+    fn try_grow(&self, capacity: usize) -> Result<()> {
+        self.0
+            .lock()
+            .as_mut()
+            .ok_or_else(|| {
+                DataFusionError::Internal(
+                    "inner memory reservation not initialized".to_string(),
+                )
+            })?
+            .try_grow(capacity)
+    }
+}
+
+impl Default for SharedOptionalMemoryReservation {
+    fn default() -> Self {
+        Self(Arc::new(Mutex::new(None)))
+    }
+}
+
 const TB: u64 = 1 << 40;
 const GB: u64 = 1 << 30;
 const MB: u64 = 1 << 20;
@@ -219,4 +278,63 @@ mod tests {
         a2.try_grow(25).unwrap();
         assert_eq!(pool.reserved(), 25);
     }
+
+    #[test]
+    fn test_shared_memory_reservation() {
+        let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
+        let a1 = 
SharedMemoryReservation::from(MemoryConsumer::new("a1").register(&pool));
+        let a2 = a1.clone();
+
+        // Reserve from a1
+        a1.try_grow(10).unwrap();
+        assert_eq!(pool.reserved(), 10);
+
+        // Drop a1 - normally reservation calls `free` on drop.
+        // Ensure that reservation still alive in a2
+        drop(a1);
+        assert_eq!(pool.reserved(), 10);
+
+        // Ensure that after a2 dropped, memory gets back to the pool
+        drop(a2);
+        assert_eq!(pool.reserved(), 0);
+    }
+
+    #[test]
+    fn test_optional_shared_memory_reservation() {
+        let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
+        let a1 = SharedOptionalMemoryReservation::default();
+
+        // try_grow on empty inner reservation
+        let err = a1.try_grow(10).unwrap_err();
+        assert_eq!(
+            err.to_string(),
+            "Internal error: inner memory reservation not initialized. \
+             This was likely caused by a bug in DataFusion's code and we \
+             would welcome that you file an bug report in our issue tracker"
+        );
+
+        // multiple initializations
+        a1.initialize("a1", &pool);
+        a1.initialize("a2", &pool);
+        {
+            let locked = a1.0.lock();
+            let name = locked.as_ref().unwrap().consumer.name();
+            assert_eq!(name, "a1");
+        }
+
+        let a2 = a1.clone();
+
+        // Reserve from a1
+        a1.try_grow(10).unwrap();
+        assert_eq!(pool.reserved(), 10);
+
+        // Drop a1 - normally reservation calls `free` on drop.
+        // Ensure that reservation still alive in a2
+        drop(a1);
+        assert_eq!(pool.reserved(), 10);
+
+        // Ensure that after a2 dropped, memory gets back to the pool
+        drop(a2);
+        assert_eq!(pool.reserved(), 0);
+    }
 }

Reply via email to