This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 612eb1d0c Memory limited nested-loop join (#5564)
612eb1d0c is described below
commit 612eb1d0ce338af7980fa906df8796eb47c4be44
Author: Eduard Karacharov <[email protected]>
AuthorDate: Tue Mar 14 16:32:58 2023 +0300
Memory limited nested-loop join (#5564)
* memory-limited nested-loop join
* shared reservations as structs
---
datafusion/core/src/physical_plan/common.rs | 5 -
.../core/src/physical_plan/joins/cross_join.rs | 28 +--
.../core/src/physical_plan/joins/hash_join.rs | 43 ++---
.../src/physical_plan/joins/nested_loop_join.rs | 195 +++++++++++++++++++--
datafusion/core/tests/memory_limit.rs | 10 ++
datafusion/execution/src/memory_pool/mod.rs | 120 ++++++++++++-
6 files changed, 329 insertions(+), 72 deletions(-)
diff --git a/datafusion/core/src/physical_plan/common.rs
b/datafusion/core/src/physical_plan/common.rs
index 2f02aaa27..7fb67d758 100644
--- a/datafusion/core/src/physical_plan/common.rs
+++ b/datafusion/core/src/physical_plan/common.rs
@@ -42,11 +42,6 @@ use tokio::task::JoinHandle;
/// [`MemoryReservation`] used across query execution streams
pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
-/// [`MemoryReservation`] used at query operator level
-/// `Option` wrapper allows to initialize empty reservation in operator
constructor,
-/// and set it to actual reservation at stream level.
-pub(crate) type OperatorMemoryReservation =
Arc<Mutex<Option<SharedMemoryReservation>>>;
-
/// Stream of record batches
pub struct SizedRecordBatchStream {
schema: SchemaRef,
diff --git a/datafusion/core/src/physical_plan/joins/cross_join.rs
b/datafusion/core/src/physical_plan/joins/cross_join.rs
index d4933b9d6..8492e5e6b 100644
--- a/datafusion/core/src/physical_plan/joins/cross_join.rs
+++ b/datafusion/core/src/physical_plan/joins/cross_join.rs
@@ -26,8 +26,7 @@ use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use crate::execution::context::TaskContext;
-use crate::execution::memory_pool::MemoryConsumer;
-use crate::physical_plan::common::{OperatorMemoryReservation,
SharedMemoryReservation};
+use crate::execution::memory_pool::{SharedOptionalMemoryReservation, TryGrow};
use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::physical_plan::{
coalesce_batches::concat_batches,
coalesce_partitions::CoalescePartitionsExec,
@@ -38,7 +37,6 @@ use crate::physical_plan::{
use crate::{error::Result, scalar::ScalarValue};
use async_trait::async_trait;
use datafusion_common::DataFusionError;
-use parking_lot::Mutex;
use super::utils::{
adjust_right_output_partitioning, cross_join_equivalence_properties,
@@ -61,7 +59,7 @@ pub struct CrossJoinExec {
/// Build-side data
left_fut: OnceAsync<JoinLeftData>,
/// Memory reservation for build-side data
- reservation: OperatorMemoryReservation,
+ reservation: SharedOptionalMemoryReservation,
/// Execution plan metrics
metrics: ExecutionPlanMetricsSet,
}
@@ -106,7 +104,7 @@ async fn load_left_input(
left: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
metrics: BuildProbeJoinMetrics,
- reservation: SharedMemoryReservation,
+ reservation: SharedOptionalMemoryReservation,
) -> Result<JoinLeftData> {
// merge all left parts into a single stream
let merge = {
@@ -125,7 +123,7 @@ async fn load_left_input(
|mut acc, batch| async {
let batch_size = batch.get_array_memory_size();
// Reserve memory for incoming batch
- acc.3.lock().try_grow(batch_size)?;
+ acc.3.try_grow(batch_size)?;
// Update metrics
acc.2.build_mem_used.add(batch_size);
acc.2.build_input_batches.add(1);
@@ -226,27 +224,15 @@ impl ExecutionPlan for CrossJoinExec {
let join_metrics = BuildProbeJoinMetrics::new(partition,
&self.metrics);
// Initialization of operator-level reservation
- {
- let mut reservation_lock = self.reservation.lock();
- if reservation_lock.is_none() {
- *reservation_lock = Some(Arc::new(Mutex::new(
-
MemoryConsumer::new("CrossJoinExec").register(context.memory_pool()),
- )));
- };
- }
-
- let reservation = self.reservation.lock().clone().ok_or_else(|| {
- DataFusionError::Internal(
- "Operator-level memory reservation is not
initialized".to_string(),
- )
- })?;
+ self.reservation
+ .initialize("CrossJoinExec", context.memory_pool());
let left_fut = self.left_fut.once(|| {
load_left_input(
self.left.clone(),
context,
join_metrics.clone(),
- reservation,
+ self.reservation.clone(),
)
});
diff --git a/datafusion/core/src/physical_plan/joins/hash_join.rs
b/datafusion/core/src/physical_plan/joins/hash_join.rs
index 0d2b897dd..39acffa20 100644
--- a/datafusion/core/src/physical_plan/joins/hash_join.rs
+++ b/datafusion/core/src/physical_plan/joins/hash_join.rs
@@ -58,7 +58,6 @@ use hashbrown::raw::RawTable;
use crate::physical_plan::{
coalesce_batches::concat_batches,
coalesce_partitions::CoalescePartitionsExec,
- common::{OperatorMemoryReservation, SharedMemoryReservation},
expressions::Column,
expressions::PhysicalSortExpr,
hash_utils::create_hashes,
@@ -78,7 +77,12 @@ use crate::logical_expr::JoinType;
use crate::arrow::array::BooleanBufferBuilder;
use crate::arrow::datatypes::TimeUnit;
-use crate::execution::{context::TaskContext, memory_pool::MemoryConsumer};
+use crate::execution::{
+ context::TaskContext,
+ memory_pool::{
+ MemoryConsumer, SharedMemoryReservation,
SharedOptionalMemoryReservation, TryGrow,
+ },
+};
use super::{
utils::{OnceAsync, OnceFut},
@@ -88,7 +92,6 @@ use crate::physical_plan::joins::utils::{
adjust_indices_by_join_type, apply_join_filter_to_indices,
build_batch_from_indices,
get_final_indices_from_bit_map, need_produce_result_in_final, JoinSide,
};
-use parking_lot::Mutex;
use std::fmt;
use std::task::Poll;
@@ -137,7 +140,7 @@ pub struct HashJoinExec {
/// Build-side data
left_fut: OnceAsync<JoinLeftData>,
/// Operator-level memory reservation for left data
- reservation: OperatorMemoryReservation,
+ reservation: SharedOptionalMemoryReservation,
/// Shares the `RandomState` for the hashing algorithm
random_state: RandomState,
/// Partitioning mode to use
@@ -378,26 +381,14 @@ impl ExecutionPlan for HashJoinExec {
let join_metrics = BuildProbeJoinMetrics::new(partition,
&self.metrics);
// Initialization of operator-level reservation
- {
- let mut operator_reservation_lock = self.reservation.lock();
- if operator_reservation_lock.is_none() {
- *operator_reservation_lock = Some(Arc::new(Mutex::new(
-
MemoryConsumer::new("HashJoinExec").register(context.memory_pool()),
- )));
- };
- }
-
- let operator_reservation =
self.reservation.lock().clone().ok_or_else(|| {
- DataFusionError::Internal(
- "Operator-level memory reservation is not
initialized".to_string(),
- )
- })?;
+ self.reservation
+ .initialize("HashJoinExec", context.memory_pool());
// Inititalization of stream-level reservation
- let reservation = Arc::new(Mutex::new(
+ let reservation = SharedMemoryReservation::from(
MemoryConsumer::new(format!("HashJoinStream[{partition}]"))
.register(context.memory_pool()),
- ));
+ );
// Memory reservation for left-side data depends on PartitionMode:
// - operator-level for `CollectLeft` mode
@@ -415,7 +406,7 @@ impl ExecutionPlan for HashJoinExec {
on_left.clone(),
context.clone(),
join_metrics.clone(),
- operator_reservation.clone(),
+ Arc::new(self.reservation.clone()),
)
}),
PartitionMode::Partitioned => OnceFut::new(collect_left_input(
@@ -425,7 +416,7 @@ impl ExecutionPlan for HashJoinExec {
on_left.clone(),
context.clone(),
join_metrics.clone(),
- reservation.clone(),
+ Arc::new(reservation.clone()),
)),
PartitionMode::Auto => {
return Err(DataFusionError::Plan(format!(
@@ -497,7 +488,7 @@ async fn collect_left_input(
on_left: Vec<Column>,
context: Arc<TaskContext>,
metrics: BuildProbeJoinMetrics,
- reservation: SharedMemoryReservation,
+ reservation: Arc<dyn TryGrow>,
) -> Result<JoinLeftData> {
let schema = left.schema();
@@ -526,7 +517,7 @@ async fn collect_left_input(
.try_fold(initial, |mut acc, batch| async {
let batch_size = batch.get_array_memory_size();
// Reserve memory for incoming batch
- acc.3.lock().try_grow(batch_size)?;
+ acc.3.try_grow(batch_size)?;
// Update metrics
acc.2.build_mem_used.add(batch_size);
acc.2.build_input_batches.add(1);
@@ -555,7 +546,7 @@ async fn collect_left_input(
// + 16 bytes fixed
let estimated_hastable_size = 32 * estimated_buckets + estimated_buckets +
16;
- reservation.lock().try_grow(estimated_hastable_size)?;
+ reservation.try_grow(estimated_hastable_size)?;
metrics.build_mem_used.add(estimated_hastable_size);
let mut hashmap = JoinHashMap(RawTable::with_capacity(num_rows));
@@ -1157,7 +1148,7 @@ impl HashJoinStream {
// TODO: Replace `ceil` wrapper with stable `div_cell` after
// https://github.com/rust-lang/rust/issues/88581
let visited_bitmap_size = bit_util::ceil(left_data.1.num_rows(),
8);
- self.reservation.lock().try_grow(visited_bitmap_size)?;
+ self.reservation.try_grow(visited_bitmap_size)?;
self.join_metrics.build_mem_used.add(visited_bitmap_size);
}
diff --git a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
index c283b11f8..e04e86d0d 100644
--- a/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
+++ b/datafusion/core/src/physical_plan/joins/nested_loop_join.rs
@@ -24,8 +24,10 @@ use crate::physical_plan::joins::utils::{
build_batch_from_indices, build_join_schema, check_join_is_valid,
combine_join_equivalence_properties, estimate_join_statistics,
get_anti_indices,
get_anti_u64_indices, get_final_indices_from_bit_map, get_semi_indices,
- get_semi_u64_indices, ColumnIndex, JoinFilter, JoinSide, OnceAsync,
OnceFut,
+ get_semi_u64_indices, BuildProbeJoinMetrics, ColumnIndex, JoinFilter,
JoinSide,
+ OnceAsync, OnceFut,
};
+use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricsSet};
use crate::physical_plan::{
DisplayFormatType, Distribution, ExecutionPlan, Partitioning,
RecordBatchStream,
SendableRecordBatchStream,
@@ -35,19 +37,21 @@ use arrow::array::{
};
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
+use arrow::util::bit_util;
use datafusion_common::{DataFusionError, Statistics};
use datafusion_expr::JoinType;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalSortExpr};
use futures::{ready, Stream, StreamExt, TryStreamExt};
-use log::debug;
use std::any::Any;
use std::fmt::Formatter;
use std::sync::Arc;
use std::task::Poll;
-use std::time::Instant;
use crate::error::Result;
use crate::execution::context::TaskContext;
+use crate::execution::memory_pool::{
+ MemoryConsumer, SharedMemoryReservation, SharedOptionalMemoryReservation,
TryGrow,
+};
use crate::physical_plan::coalesce_batches::concat_batches;
/// Data of the inner table side
@@ -87,6 +91,10 @@ pub struct NestedLoopJoinExec {
inner_table: OnceAsync<JoinLeftData>,
/// Information of index and left / right placement of columns
column_indices: Vec<ColumnIndex>,
+ /// Operator-level memory reservation for left data
+ reservation: SharedOptionalMemoryReservation,
+ /// Execution metrics
+ metrics: ExecutionPlanMetricsSet,
}
impl NestedLoopJoinExec {
@@ -110,6 +118,8 @@ impl NestedLoopJoinExec {
schema: Arc::new(schema),
inner_table: Default::default(),
column_indices,
+ reservation: Default::default(),
+ metrics: Default::default(),
})
}
}
@@ -189,17 +199,41 @@ impl ExecutionPlan for NestedLoopJoinExec {
partition: usize,
context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
+ let join_metrics = BuildProbeJoinMetrics::new(partition,
&self.metrics);
+
+ // Initialization of operator-level reservation
+ self.reservation
+ .initialize("NestedLoopJoinExec", context.memory_pool());
+
+ // Initialization of stream-level reservation
+ let reservation = SharedMemoryReservation::from(
+ MemoryConsumer::new(format!("NestedLoopJoinStream[{partition}]"))
+ .register(context.memory_pool()),
+ );
+
let (outer_table, inner_table) = if left_is_build_side(self.join_type)
{
// left must be single partition
let inner_table = self.inner_table.once(|| {
- load_specified_partition_of_input(0, self.left.clone(),
context.clone())
+ load_specified_partition_of_input(
+ 0,
+ self.left.clone(),
+ context.clone(),
+ join_metrics.clone(),
+ Arc::new(self.reservation.clone()),
+ )
});
let outer_table = self.right.execute(partition, context)?;
(outer_table, inner_table)
} else {
// right must be single partition
let inner_table = self.inner_table.once(|| {
- load_specified_partition_of_input(0, self.right.clone(),
context.clone())
+ load_specified_partition_of_input(
+ 0,
+ self.right.clone(),
+ context.clone(),
+ join_metrics.clone(),
+ Arc::new(self.reservation.clone()),
+ )
});
let outer_table = self.left.execute(partition, context)?;
(outer_table, inner_table)
@@ -214,6 +248,8 @@ impl ExecutionPlan for NestedLoopJoinExec {
is_exhausted: false,
visited_left_side: None,
column_indices: self.column_indices.clone(),
+ join_metrics,
+ reservation,
}))
}
@@ -233,6 +269,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
}
}
+ fn metrics(&self) -> Option<MetricsSet> {
+ Some(self.metrics.clone_inner())
+ }
+
fn statistics(&self) -> Statistics {
estimate_join_statistics(
self.left.clone(),
@@ -273,28 +313,34 @@ async fn load_specified_partition_of_input(
partition: usize,
input: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
+ join_metrics: BuildProbeJoinMetrics,
+ reservation: Arc<dyn TryGrow>,
) -> Result<JoinLeftData> {
- let start = Instant::now();
let stream = input.execute(partition, context)?;
// Load all batches and count the rows
- let (batches, num_rows) = stream
- .try_fold((Vec::new(), 0usize), |mut acc, batch| async {
- acc.1 += batch.num_rows();
- acc.0.push(batch);
- Ok(acc)
- })
+ let (batches, num_rows, _, _) = stream
+ .try_fold(
+ (Vec::new(), 0usize, join_metrics, reservation),
+ |mut acc, batch| async {
+ let batch_size = batch.get_array_memory_size();
+ // Reserve memory for incoming batch
+ acc.3.try_grow(batch_size)?;
+ // Update metrics
+ acc.2.build_mem_used.add(batch_size);
+ acc.2.build_input_batches.add(1);
+ acc.2.build_input_rows.add(batch.num_rows());
+ // Update rowcount
+ acc.1 += batch.num_rows();
+ // Push batch to output
+ acc.0.push(batch);
+ Ok(acc)
+ },
+ )
.await?;
let merged_batch = concat_batches(&input.schema(), &batches, num_rows)?;
- debug!(
- "Built input of nested loop join containing {} rows in {} ms for
partition {}",
- num_rows,
- start.elapsed().as_millis(),
- partition
- );
-
Ok(merged_batch)
}
@@ -326,6 +372,10 @@ struct NestedLoopJoinStream {
column_indices: Vec<ColumnIndex>,
// TODO: support null aware equal
// null_equals_null: bool
+ /// Join execution metrics
+ join_metrics: BuildProbeJoinMetrics,
+ /// Memory reservation for visited_left_side
+ reservation: SharedMemoryReservation,
}
fn build_join_indices(
@@ -362,10 +412,20 @@ impl NestedLoopJoinStream {
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
// all left row
+ let build_timer = self.join_metrics.build_time.timer();
let left_data = match ready!(self.inner_table.get(cx)) {
Ok(data) => data,
Err(e) => return Poll::Ready(Some(Err(e))),
};
+ build_timer.done();
+
+ if self.visited_left_side.is_none() && self.join_type ==
JoinType::Full {
+ // TODO: Replace `ceil` wrapper with stable `div_ceil` after
+ // https://github.com/rust-lang/rust/issues/88581
+ let visited_bitmap_size = bit_util::ceil(left_data.num_rows(), 8);
+ self.reservation.try_grow(visited_bitmap_size)?;
+ self.join_metrics.build_mem_used.add(visited_bitmap_size);
+ }
// add a bitmap for full join.
let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
@@ -384,6 +444,11 @@ impl NestedLoopJoinStream {
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
Some(Ok(right_batch)) => {
+ // Setting up timer & updating input metrics
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(right_batch.num_rows());
+ let timer = self.join_metrics.join_time.timer();
+
let result = join_left_and_right_batch(
left_data,
&right_batch,
@@ -393,11 +458,22 @@ impl NestedLoopJoinStream {
&self.schema,
visited_left_side,
);
+
+ // Recording time & updating output metrics
+ if let Ok(batch) = &result {
+ timer.done();
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ }
+
Some(result)
}
Some(err) => Some(err),
None => {
if self.join_type == JoinType::Full && !self.is_exhausted {
+ // Only setting up timer, input is exhausted
+ let timer = self.join_metrics.join_time.timer();
+
// use the global left bitmap to produce the left
indices and right indices
let (left_side, right_side) =
get_final_indices_from_bit_map(
visited_left_side,
@@ -416,6 +492,14 @@ impl NestedLoopJoinStream {
JoinSide::Left,
);
self.is_exhausted = true;
+
+ // Recording time & updating output metrics
+ if let Ok(batch) = &result {
+ timer.done();
+ self.join_metrics.output_batches.add(1);
+
self.join_metrics.output_rows.add(batch.num_rows());
+ }
+
Some(result)
} else {
// end of the join loop
@@ -431,10 +515,12 @@ impl NestedLoopJoinStream {
cx: &mut std::task::Context<'_>,
) -> Poll<Option<Result<RecordBatch>>> {
// all right row
+ let build_timer = self.join_metrics.build_time.timer();
let right_data = match ready!(self.inner_table.get(cx)) {
Ok(data) => data,
Err(e) => return Poll::Ready(Some(Err(e))),
};
+ build_timer.done();
// for build right, bitmap is not needed.
let mut empty_visited_left_side = BooleanBufferBuilder::new(0);
@@ -442,6 +528,12 @@ impl NestedLoopJoinStream {
.poll_next_unpin(cx)
.map(|maybe_batch| match maybe_batch {
Some(Ok(left_batch)) => {
+ // Setting up timer & updating input metrics
+ self.join_metrics.input_batches.add(1);
+ self.join_metrics.input_rows.add(left_batch.num_rows());
+ let timer = self.join_metrics.join_time.timer();
+
+ // Actual join execution
let result = join_left_and_right_batch(
&left_batch,
right_data,
@@ -451,6 +543,14 @@ impl NestedLoopJoinStream {
&self.schema,
&mut empty_visited_left_side,
);
+
+ // Recording time & updating output metrics
+ if let Ok(batch) = &result {
+ timer.done();
+ self.join_metrics.output_batches.add(1);
+ self.join_metrics.output_rows.add(batch.num_rows());
+ }
+
Some(result)
}
Some(err) => Some(err),
@@ -633,6 +733,11 @@ mod tests {
use crate::physical_expr::expressions::BinaryExpr;
use crate::{
assert_batches_sorted_eq,
+ common::assert_contains,
+ execution::{
+ context::SessionConfig,
+ runtime_env::{RuntimeConfig, RuntimeEnv},
+ },
physical_plan::{
common, expressions::Column, memory::MemoryExec,
repartition::RepartitionExec,
},
@@ -1016,4 +1121,56 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn test_overallocation() -> Result<()> {
+ let left = build_table(
+ ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+ ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+ ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+ );
+ let right = build_table(
+ ("a2", &vec![10, 11]),
+ ("b2", &vec![12, 13]),
+ ("c2", &vec![14, 15]),
+ );
+ let filter = prepare_join_filter();
+
+ let join_types = vec![
+ JoinType::Inner,
+ JoinType::Left,
+ JoinType::Right,
+ JoinType::Full,
+ JoinType::LeftSemi,
+ JoinType::LeftAnti,
+ JoinType::RightSemi,
+ JoinType::RightAnti,
+ ];
+
+ for join_type in join_types {
+ let runtime_config = RuntimeConfig::new().with_memory_limit(100,
1.0);
+ let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
+ let session_ctx =
+ SessionContext::with_config_rt(SessionConfig::default(),
runtime);
+ let task_ctx = session_ctx.task_ctx();
+
+ let err = multi_partitioned_join_collect(
+ left.clone(),
+ right.clone(),
+ &join_type,
+ Some(filter.clone()),
+ task_ctx,
+ )
+ .await
+ .unwrap_err();
+
+ assert_contains!(
+ err.to_string(),
+ "External error: Resources exhausted: Failed to allocate
additional"
+ );
+ assert_contains!(err.to_string(), "NestedLoopJoinExec");
+ }
+
+ Ok(())
+ }
}
diff --git a/datafusion/core/tests/memory_limit.rs
b/datafusion/core/tests/memory_limit.rs
index 392e54941..9de8aff77 100644
--- a/datafusion/core/tests/memory_limit.rs
+++ b/datafusion/core/tests/memory_limit.rs
@@ -84,6 +84,16 @@ async fn join_by_key() {
.await
}
+#[tokio::test]
+async fn join_by_expression() {
+ run_limit_test(
+ "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service",
+ "Resources exhausted: Failed to allocate additional",
+ 1_000,
+ )
+ .await
+}
+
#[tokio::test]
async fn cross_join() {
run_limit_test(
diff --git a/datafusion/execution/src/memory_pool/mod.rs
b/datafusion/execution/src/memory_pool/mod.rs
index f68a25650..f1f745d4e 100644
--- a/datafusion/execution/src/memory_pool/mod.rs
+++ b/datafusion/execution/src/memory_pool/mod.rs
@@ -17,7 +17,8 @@
//! Manages all available memory during query execution
-use datafusion_common::Result;
+use datafusion_common::{DataFusionError, Result};
+use parking_lot::Mutex;
use std::sync::Arc;
mod pool;
@@ -163,6 +164,64 @@ impl Drop for MemoryReservation {
}
}
+pub trait TryGrow: Send + Sync + std::fmt::Debug {
+ fn try_grow(&self, capacity: usize) -> Result<()>;
+}
+
+/// Cloneable reference to [`MemoryReservation`] instance with interior mutability support
+#[derive(Clone, Debug)]
+pub struct SharedMemoryReservation(Arc<Mutex<MemoryReservation>>);
+
+impl From<MemoryReservation> for SharedMemoryReservation {
+ /// Creates new [`SharedMemoryReservation`] from [`MemoryReservation`]
+ fn from(reservation: MemoryReservation) -> Self {
+ Self(Arc::new(Mutex::new(reservation)))
+ }
+}
+
+impl TryGrow for SharedMemoryReservation {
+ /// Try to increase the size of this reservation by `capacity` bytes
+ fn try_grow(&self, capacity: usize) -> Result<()> {
+ self.0.lock().try_grow(capacity)
+ }
+}
+
+/// Cloneable reference to [`MemoryReservation`] instance with interior mutability support.
+/// Doesn't require a [`MemoryReservation`] at creation time, and can be initialized later.
+#[derive(Clone, Debug)]
+pub struct
SharedOptionalMemoryReservation(Arc<Mutex<Option<MemoryReservation>>>);
+
+impl SharedOptionalMemoryReservation {
+ /// Initialize inner [`MemoryReservation`] if it is `None`; otherwise do nothing
+ pub fn initialize(&self, name: impl Into<String>, pool: &Arc<dyn
MemoryPool>) {
+ let mut locked = self.0.lock();
+ if locked.is_none() {
+ *locked = Some(MemoryConsumer::new(name).register(pool));
+ };
+ }
+}
+
+impl TryGrow for SharedOptionalMemoryReservation {
+ /// Try to increase the size of this reservation by `capacity` bytes
+ fn try_grow(&self, capacity: usize) -> Result<()> {
+ self.0
+ .lock()
+ .as_mut()
+ .ok_or_else(|| {
+ DataFusionError::Internal(
+ "inner memory reservation not initialized".to_string(),
+ )
+ })?
+ .try_grow(capacity)
+ }
+}
+
+impl Default for SharedOptionalMemoryReservation {
+ fn default() -> Self {
+ Self(Arc::new(Mutex::new(None)))
+ }
+}
+
const TB: u64 = 1 << 40;
const GB: u64 = 1 << 30;
const MB: u64 = 1 << 20;
@@ -219,4 +278,63 @@ mod tests {
a2.try_grow(25).unwrap();
assert_eq!(pool.reserved(), 25);
}
+
+ #[test]
+ fn test_shared_memory_reservation() {
+ let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
+ let a1 =
SharedMemoryReservation::from(MemoryConsumer::new("a1").register(&pool));
+ let a2 = a1.clone();
+
+ // Reserve from a1
+ a1.try_grow(10).unwrap();
+ assert_eq!(pool.reserved(), 10);
+
+ // Drop a1 - normally reservation calls `free` on drop.
+ // Ensure that reservation still alive in a2
+ drop(a1);
+ assert_eq!(pool.reserved(), 10);
+
+ // Ensure that after a2 dropped, memory gets back to the pool
+ drop(a2);
+ assert_eq!(pool.reserved(), 0);
+ }
+
+ #[test]
+ fn test_optional_shared_memory_reservation() {
+ let pool = Arc::new(GreedyMemoryPool::new(50)) as _;
+ let a1 = SharedOptionalMemoryReservation::default();
+
+ // try_grow on empty inner reservation
+ let err = a1.try_grow(10).unwrap_err();
+ assert_eq!(
+ err.to_string(),
+ "Internal error: inner memory reservation not initialized. \
+ This was likely caused by a bug in DataFusion's code and we \
+ would welcome that you file an bug report in our issue tracker"
+ );
+
+ // multiple initializations
+ a1.initialize("a1", &pool);
+ a1.initialize("a2", &pool);
+ {
+ let locked = a1.0.lock();
+ let name = locked.as_ref().unwrap().consumer.name();
+ assert_eq!(name, "a1");
+ }
+
+ let a2 = a1.clone();
+
+ // Reserve from a1
+ a1.try_grow(10).unwrap();
+ assert_eq!(pool.reserved(), 10);
+
+ // Drop a1 - normally reservation calls `free` on drop.
+ // Ensure that reservation still alive in a2
+ drop(a1);
+ assert_eq!(pool.reserved(), 10);
+
+ // Ensure that after a2 dropped, memory gets back to the pool
+ drop(a2);
+ assert_eq!(pool.reserved(), 0);
+ }
}