This is an automated email from the ASF dual-hosted git repository.
dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new c22f575 Clean up spawned task on drop for `RepartitionExec`,
`SortPreservingMergeExec`, `WindowAggExec` (#1112)
c22f575 is described below
commit c22f5753fc7e3dfc121adcc7d39a74bc2199d308
Author: Marco Neumann <[email protected]>
AuthorDate: Thu Oct 14 18:06:42 2021 +0200
Clean up spawned task on drop for `RepartitionExec`,
`SortPreservingMergeExec`, `WindowAggExec` (#1112)
* allow `BlockingExec` to mock multiple partitions
* Clean up spawned task on `SortPreservingMergeExec` drop
Ref #1103.
* Clean up spawned task on `WindowAggExec` drop
Ref #1103.
* isolate common test code
* Clean up spawned task on `RepartitionExec` drop
Ref #1103.
* rename variables / fix copy+paste mistake
* introduce `RepartitionExecState` struct
* introduce `AbortOnDrop{Single,Many}` helpers
---
datafusion/src/physical_plan/common.rs | 50 +++++++++++-
datafusion/src/physical_plan/repartition.rs | 94 ++++++++++++++++++----
datafusion/src/physical_plan/sort.rs | 35 ++------
.../src/physical_plan/sort_preserving_merge.rs | 86 +++++++++++++++-----
datafusion/src/physical_plan/windows/mod.rs | 37 ++++++++-
.../src/physical_plan/windows/window_agg_exec.rs | 25 +++---
datafusion/src/test/exec.rs | 28 ++++++-
datafusion/src/test/mod.rs | 11 +++
8 files changed, 285 insertions(+), 81 deletions(-)
diff --git a/datafusion/src/physical_plan/common.rs
b/datafusion/src/physical_plan/common.rs
index 3be9e72..d6a37e0 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -26,7 +26,8 @@ use arrow::error::ArrowError;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::channel::mpsc;
-use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
+use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt};
+use pin_project_lite::pin_project;
use std::fs;
use std::fs::metadata;
use std::sync::Arc;
@@ -225,6 +226,53 @@ pub fn compute_record_batch_statistics(
}
}
+pin_project! {
+ /// Helper that aborts the given join handle on drop.
+ ///
+ /// Useful to kill background tasks when the consumer is dropped.
+ #[derive(Debug)]
+ pub struct AbortOnDropSingle<T>{
+ #[pin]
+ join_handle: JoinHandle<T>,
+ }
+
+ impl<T> PinnedDrop for AbortOnDropSingle<T> {
+ fn drop(this: Pin<&mut Self>) {
+ this.join_handle.abort();
+ }
+ }
+}
+
+impl<T> AbortOnDropSingle<T> {
+ /// Create new abort helper from join handle.
+ pub fn new(join_handle: JoinHandle<T>) -> Self {
+ Self { join_handle }
+ }
+}
+
+impl<T> Future for AbortOnDropSingle<T> {
+ type Output = std::result::Result<T, tokio::task::JoinError>;
+
+ fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) ->
Poll<Self::Output> {
+ let this = self.project();
+ this.join_handle.poll(cx)
+ }
+}
+
+/// Helper that aborts the given join handles on drop.
+///
+/// Useful to kill background tasks when the consumer is dropped.
+#[derive(Debug)]
+pub struct AbortOnDropMany<T>(pub Vec<JoinHandle<T>>);
+
+impl<T> Drop for AbortOnDropMany<T> {
+ fn drop(&mut self) {
+ for join_handle in &self.0 {
+ join_handle.abort();
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/datafusion/src/physical_plan/repartition.rs
b/datafusion/src/physical_plan/repartition.rs
index 56de364..74696ea 100644
--- a/datafusion/src/physical_plan/repartition.rs
+++ b/datafusion/src/physical_plan/repartition.rs
@@ -31,6 +31,7 @@ use arrow::{array::Array, error::Result as ArrowResult};
use arrow::{compute::take, datatypes::SchemaRef};
use tokio_stream::wrappers::UnboundedReceiverStream;
+use super::common::{AbortOnDropMany, AbortOnDropSingle};
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{RecordBatchStream, SendableRecordBatchStream};
use async_trait::async_trait;
@@ -46,21 +47,30 @@ use tokio::task::JoinHandle;
type MaybeBatch = Option<ArrowResult<RecordBatch>>;
+/// Inner state of [`RepartitionExec`].
+#[derive(Debug)]
+struct RepartitionExecState {
+ /// Channels for sending batches from input partitions to output
partitions.
+ /// Key is the partition number.
+ channels:
+ HashMap<usize, (UnboundedSender<MaybeBatch>,
UnboundedReceiver<MaybeBatch>)>,
+
+ /// Helper that ensures that the background job is killed once it is no
longer needed.
+ abort_helper: Arc<AbortOnDropMany<()>>,
+}
+
/// The repartition operator maps N input partitions to M output partitions
based on a
/// partitioning scheme. No guarantees are made about the order of the
resulting partitions.
#[derive(Debug)]
pub struct RepartitionExec {
/// Input execution plan
input: Arc<dyn ExecutionPlan>,
+
/// Partitioning scheme to use
partitioning: Partitioning,
- /// Channels for sending batches from input partitions to output
partitions.
- /// Key is the partition number
- channels: Arc<
- Mutex<
- HashMap<usize, (UnboundedSender<MaybeBatch>,
UnboundedReceiver<MaybeBatch>)>,
- >,
- >,
+
+ /// Inner state that is initialized when the first output stream is
created.
+ state: Arc<Mutex<RepartitionExecState>>,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
@@ -156,13 +166,13 @@ impl ExecutionPlan for RepartitionExec {
async fn execute(&self, partition: usize) ->
Result<SendableRecordBatchStream> {
// lock mutexes
- let mut channels = self.channels.lock().await;
+ let mut state = self.state.lock().await;
let num_input_partitions =
self.input.output_partitioning().partition_count();
let num_output_partitions = self.partitioning.partition_count();
// if this is the first partition to be invoked then we need to set up
initial state
- if channels.is_empty() {
+ if state.channels.is_empty() {
// create one channel per *output* partition
for partition in 0..num_output_partitions {
// Note that this operator uses unbounded channels to avoid
deadlocks because
@@ -173,14 +183,16 @@ impl ExecutionPlan for RepartitionExec {
// for this would be to add spill-to-disk capabilities.
let (sender, receiver) =
mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
- channels.insert(partition, (sender, receiver));
+ state.channels.insert(partition, (sender, receiver));
}
// Use fixed random state
let random = ahash::RandomState::with_seeds(0, 0, 0, 0);
// launch one async task per *input* partition
+ let mut join_handles = Vec::with_capacity(num_input_partitions);
for i in 0..num_input_partitions {
- let txs: HashMap<_, _> = channels
+ let txs: HashMap<_, _> = state
+ .channels
.iter()
.map(|(partition, (tx, _rx))| (*partition, tx.clone()))
.collect();
@@ -199,8 +211,14 @@ impl ExecutionPlan for RepartitionExec {
// In a separate task, wait for each input to be done
// (and pass along any errors, including panic!s)
- tokio::spawn(Self::wait_for_task(input_task, txs));
+ let join_handle = tokio::spawn(Self::wait_for_task(
+ AbortOnDropSingle::new(input_task),
+ txs,
+ ));
+ join_handles.push(join_handle);
}
+
+ state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
}
// now return stream for the specified *output* partition which will
@@ -209,7 +227,10 @@ impl ExecutionPlan for RepartitionExec {
num_input_partitions,
num_input_partitions_processed: 0,
schema: self.input.schema(),
- input:
UnboundedReceiverStream::new(channels.remove(&partition).unwrap().1),
+ input: UnboundedReceiverStream::new(
+ state.channels.remove(&partition).unwrap().1,
+ ),
+ drop_helper: Arc::clone(&state.abort_helper),
}))
}
@@ -243,7 +264,10 @@ impl RepartitionExec {
Ok(RepartitionExec {
input,
partitioning,
- channels: Arc::new(Mutex::new(HashMap::new())),
+ state: Arc::new(Mutex::new(RepartitionExecState {
+ channels: HashMap::new(),
+ abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
+ })),
metrics: ExecutionPlanMetricsSet::new(),
})
}
@@ -372,7 +396,7 @@ impl RepartitionExec {
/// complete. Upon error, propagates the errors to all output tx
/// channels.
async fn wait_for_task(
- input_task: JoinHandle<Result<()>>,
+ input_task: AbortOnDropSingle<Result<()>>,
txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
) {
// wait for completion, and propagate error
@@ -409,12 +433,19 @@ impl RepartitionExec {
struct RepartitionStream {
/// Number of input partitions that will be sending batches to this output
channel
num_input_partitions: usize,
+
/// Number of input partitions that have finished sending batches to this
output channel
num_input_partitions_processed: usize,
+
/// Schema
schema: SchemaRef,
+
/// channel containing the repartitioned batches
input: UnboundedReceiverStream<Option<ArrowResult<RecordBatch>>>,
+
+ /// Handle to ensure background tasks are killed when no longer needed.
+ #[allow(dead_code)]
+ drop_helper: Arc<AbortOnDropMany<()>>,
}
impl Stream for RepartitionStream {
@@ -454,8 +485,14 @@ mod tests {
use super::*;
use crate::{
assert_batches_sorted_eq,
- physical_plan::{expressions::col, memory::MemoryExec},
- test::exec::{BarrierExec, ErrorExec, MockExec},
+ physical_plan::{collect, expressions::col, memory::MemoryExec},
+ test::{
+ assert_is_pending,
+ exec::{
+ assert_strong_count_converges_to_zero, BarrierExec,
BlockingExec,
+ ErrorExec, MockExec,
+ },
+ },
};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;
@@ -463,6 +500,7 @@ mod tests {
array::{ArrayRef, StringArray, UInt32Array},
error::ArrowError,
};
+ use futures::FutureExt;
#[tokio::test]
async fn one_to_many_round_robin() -> Result<()> {
@@ -853,4 +891,26 @@ mod tests {
let schema = batch1.schema();
BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]],
schema)
}
+
+ #[tokio::test]
+ async fn test_drop_cancel() -> Result<()> {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
2));
+ let refs = blocking_exec.refs();
+ let repartition_exec = Arc::new(RepartitionExec::try_new(
+ blocking_exec,
+ Partitioning::UnknownPartitioning(1),
+ )?);
+
+ let fut = collect(repartition_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/sort.rs
b/datafusion/src/physical_plan/sort.rs
index 68a4258..499d1f7 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -17,6 +17,7 @@
//! Defines the SORT plan
+use super::common::AbortOnDropSingle;
use super::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
@@ -40,7 +41,6 @@ use std::any::Any;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
-use tokio::task::JoinHandle;
/// Sort execution plan
#[derive(Debug)]
@@ -229,13 +229,7 @@ pin_project! {
output:
futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
finished: bool,
schema: SchemaRef,
- join_handle: JoinHandle<()>,
- }
-
- impl PinnedDrop for SortStream {
- fn drop(this: Pin<&mut Self>) {
- this.join_handle.abort();
- }
+ drop_helper: AbortOnDropSingle<()>,
}
}
@@ -273,7 +267,7 @@ impl SortStream {
output: rx,
finished: false,
schema,
- join_handle,
+ drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
@@ -315,14 +309,14 @@ impl RecordBatchStream for SortStream {
#[cfg(test)]
mod tests {
- use std::sync::Weak;
-
use super::*;
use crate::datasource::object_store::local::LocalFileSystem;
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use crate::physical_plan::expressions::col;
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::{collect, file_format::CsvExec};
+ use crate::test::assert_is_pending;
+ use crate::test::exec::assert_strong_count_converges_to_zero;
use crate::test::{self, aggr_test_schema, exec::BlockingExec};
use arrow::array::*;
use arrow::datatypes::*;
@@ -497,7 +491,7 @@ mod tests {
let schema =
Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
- let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema)));
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
let refs = blocking_exec.refs();
let sort_exec = Arc::new(SortExec::try_new(
vec
+ drop_helper: AbortOnDropMany<()>,
+
/// For each input stream maintain a dequeue of SortKeyCursor
///
/// Exhausted cursors will be popped off the front once all
/// their rows have been yielded to the output
cursors: Vec<VecDeque<SortKeyCursor>>,
+
/// The accumulated row indexes for the next record batch
in_progress: Vec<RowIndex>,
+
/// The physical expressions to sort by
column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+
/// The sort options for each expression
sort_options: Vec<SortOptions>,
+
/// The desired RecordBatch size to yield
target_batch_size: usize,
+
/// used to record execution metrics
baseline_metrics: BaselineMetrics,
+
/// If the stream has encountered an error
aborted: bool,
@@ -364,13 +378,14 @@ struct SortPreservingMergeStream {
impl SortPreservingMergeStream {
fn new(
- streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+ receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+ drop_helper: AbortOnDropMany<()>,
schema: SchemaRef,
expressions: &[PhysicalSortExpr],
target_batch_size: usize,
baseline_metrics: BaselineMetrics,
) -> Self {
- let cursors = (0..streams.len())
+ let cursors = (0..receivers.len())
.into_iter()
.map(|_| VecDeque::new())
.collect();
@@ -378,7 +393,8 @@ impl SortPreservingMergeStream {
Self {
schema,
cursors,
- streams,
+ receivers,
+ drop_helper,
column_expressions: expressions.iter().map(|x|
x.expr.clone()).collect(),
sort_options: expressions.iter().map(|x| x.options).collect(),
target_batch_size,
@@ -404,7 +420,7 @@ impl SortPreservingMergeStream {
}
}
- let stream = &mut self.streams[idx];
+ let stream = &mut self.receivers[idx];
if stream.is_terminated() {
return Poll::Ready(Ok(()));
}
@@ -644,6 +660,7 @@ impl RecordBatchStream for SortPreservingMergeStream {
mod tests {
use crate::datasource::object_store::local::LocalFileSystem;
use crate::physical_plan::metrics::MetricValue;
+ use crate::test::exec::{assert_strong_count_converges_to_zero,
BlockingExec};
use std::iter::FromIterator;
use crate::arrow::array::{Int32Array, StringArray,
TimestampNanosecondArray};
@@ -654,10 +671,11 @@ mod tests {
use crate::physical_plan::memory::MemoryExec;
use crate::physical_plan::sort::SortExec;
use crate::physical_plan::{collect, common};
- use crate::test;
+ use crate::test::{self, assert_is_pending};
use super::*;
- use futures::SinkExt;
+ use arrow::datatypes::{DataType, Field, Schema};
+ use futures::{FutureExt, SinkExt};
use tokio_stream::StreamExt;
#[tokio::test]
@@ -1172,28 +1190,30 @@ mod tests {
let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3]).await;
let partition_count = batches.output_partitioning().partition_count();
- let mut tasks = Vec::with_capacity(partition_count);
- let mut streams = Vec::with_capacity(partition_count);
+ let mut join_handles = Vec::with_capacity(partition_count);
+ let mut receivers = Vec::with_capacity(partition_count);
for partition in 0..partition_count {
let (mut sender, receiver) = mpsc::channel(1);
let mut stream = batches.execute(partition).await.unwrap();
- let task = tokio::spawn(async move {
+ let join_handle = tokio::spawn(async move {
while let Some(batch) = stream.next().await {
sender.send(batch).await.unwrap();
// This causes the MergeStream to wait for more input
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
});
- tasks.push(task);
- streams.push(receiver);
+ join_handles.push(join_handle);
+ receivers.push(receiver);
}
let metrics = ExecutionPlanMetricsSet::new();
let baseline_metrics = BaselineMetrics::new(&metrics, 0);
let merge_stream = SortPreservingMergeStream::new(
- streams,
+ receivers,
+ // Use empty vector since we want to use the join handles ourselves
+ AbortOnDropMany(vec![]),
batches.schema(),
sort.as_slice(),
1024,
@@ -1203,8 +1223,8 @@ mod tests {
let mut merged =
common::collect(Box::pin(merge_stream)).await.unwrap();
// Propagate any errors
- for task in tasks {
- task.await.unwrap();
+ for join_handle in join_handles {
+ join_handle.await.unwrap();
}
assert_eq!(merged.len(), 1);
@@ -1271,4 +1291,30 @@ mod tests {
assert!(saw_start);
assert!(saw_end);
}
+
+ #[tokio::test]
+ async fn test_drop_cancel() -> Result<()> {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
2));
+ let refs = blocking_exec.refs();
+ let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
+ vec![PhysicalSortExpr {
+ expr: col("a", &schema)?,
+ options: SortOptions::default(),
+ }],
+ blocking_exec,
+ 1,
+ ));
+
+ let fut = collect(sort_preserving_merge_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/windows/mod.rs
b/datafusion/src/physical_plan/windows/mod.rs
index 3aa67cf..ef420b2 100644
--- a/datafusion/src/physical_plan/windows/mod.rs
+++ b/datafusion/src/physical_plan/windows/mod.rs
@@ -180,10 +180,12 @@ mod tests {
use crate::physical_plan::expressions::col;
use crate::physical_plan::file_format::CsvExec;
use crate::physical_plan::{collect, Statistics};
- use crate::test::{self, aggr_test_schema};
+ use crate::test::exec::{assert_strong_count_converges_to_zero,
BlockingExec};
+ use crate::test::{self, aggr_test_schema, assert_is_pending};
use arrow::array::*;
- use arrow::datatypes::SchemaRef;
+ use arrow::datatypes::{DataType, Field, SchemaRef};
use arrow::record_batch::RecordBatch;
+ use futures::FutureExt;
fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>,
SchemaRef)> {
let schema = test::aggr_test_schema();
@@ -264,4 +266,35 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn test_drop_cancel() -> Result<()> {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
+ let refs = blocking_exec.refs();
+ let window_agg_exec = Arc::new(WindowAggExec::try_new(
+ vec![create_window_expr(
+ &WindowFunction::AggregateFunction(AggregateFunction::Count),
+ "count".to_owned(),
+ &[col("a", &schema)?],
+ &[],
+ &[],
+ Some(WindowFrame::default()),
+ schema.as_ref(),
+ )?],
+ blocking_exec,
+ schema,
+ )?);
+
+ let fut = collect(window_agg_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs
b/datafusion/src/physical_plan/windows/window_agg_exec.rs
index 03307be..228b53f 100644
--- a/datafusion/src/physical_plan/windows/window_agg_exec.rs
+++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs
@@ -18,6 +18,7 @@
//! Stream and channel implementations for window function expressions.
use crate::error::{DataFusionError, Result};
+use crate::physical_plan::common::AbortOnDropSingle;
use crate::physical_plan::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
};
@@ -219,14 +220,15 @@ fn compute_window_aggregates(
}
pin_project! {
- /// stream for window aggregation plan
- pub struct WindowAggStream {
- schema: SchemaRef,
- #[pin]
- output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
- finished: bool,
- baseline_metrics: BaselineMetrics,
- }
+ /// stream for window aggregation plan
+ pub struct WindowAggStream {
+ schema: SchemaRef,
+ drop_helper: AbortOnDropSingle<()>,
+ #[pin]
+ output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
+ finished: bool,
+ baseline_metrics: BaselineMetrics,
+ }
}
impl WindowAggStream {
@@ -240,16 +242,19 @@ impl WindowAggStream {
let (tx, rx) = futures::channel::oneshot::channel();
let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
- tokio::spawn(async move {
+ let join_handle = tokio::spawn(async move {
let schema = schema_clone.clone();
let result =
WindowAggStream::process(input, window_expr, schema,
elapsed_compute)
.await;
- tx.send(result)
+
+ // failing here is OK, the receiver is gone and does not care
about the result
+ tx.send(result).ok();
});
Self {
schema,
+ drop_helper: AbortOnDropSingle::new(join_handle),
output: rx,
finished: false,
baseline_metrics,
diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs
index 252168f..aca3b6e 100644
--- a/datafusion/src/test/exec.rs
+++ b/datafusion/src/test/exec.rs
@@ -482,15 +482,19 @@ pub struct BlockingExec {
/// Schema that is mocked by this plan.
schema: SchemaRef,
+ /// Number of output partitions.
+ n_partitions: usize,
+
/// Ref-counting helper to check if the plan and the produced stream are
still in memory.
refs: Arc<()>,
}
impl BlockingExec {
- /// Create new [`BlockingExec`] with a give schema.
- pub fn new(schema: SchemaRef) -> Self {
+ /// Create new [`BlockingExec`] with a given schema and number of
partitions.
+ pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
Self {
schema,
+ n_partitions,
refs: Default::default(),
}
}
@@ -499,7 +503,7 @@ impl BlockingExec {
///
/// Use [`Weak::strong_count`] to determine if the plan itself and its
streams are dropped (should be 0 in that
/// case). Note that tokio might take some time to cancel spawned tasks,
so you need to wrap this check into a retry
- /// loop.
+ /// loop. Use [`assert_strong_count_converges_to_zero`] to achieve this.
pub fn refs(&self) -> Weak<()> {
Arc::downgrade(&self.refs)
}
@@ -521,7 +525,7 @@ impl ExecutionPlan for BlockingExec {
}
fn output_partitioning(&self) -> Partitioning {
- Partitioning::UnknownPartitioning(1)
+ Partitioning::UnknownPartitioning(self.n_partitions)
}
fn with_new_children(
@@ -584,3 +588,19 @@ impl RecordBatchStream for BlockingStream {
Arc::clone(&self.schema)
}
}
+
+/// Asserts that the strong count of the given [`Weak`] pointer converges to
zero.
+///
+/// This might take a while but has a timeout.
+pub async fn assert_strong_count_converges_to_zero<T>(refs: Weak<T>) {
+ tokio::time::timeout(std::time::Duration::from_secs(10), async {
+ loop {
+ if dbg!(Weak::strong_count(&refs)) == 0 {
+ break;
+ }
+ tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+ }
+ })
+ .await
+ .unwrap();
+}
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index 917e7b1..c6eae35 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -28,9 +28,11 @@ use array::{
use arrow::array::{self, Int32Array};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
+use futures::{Future, FutureExt};
use std::fs::File;
use std::io::prelude::*;
use std::io::{BufReader, BufWriter};
+use std::pin::Pin;
use std::sync::Arc;
use tempfile::TempDir;
@@ -291,6 +293,15 @@ pub fn make_timestamps() -> RecordBatch {
.unwrap()
}
+/// Asserts that given future is pending.
+pub fn assert_is_pending<'a, T>(fut: &mut Pin<Box<dyn Future<Output = T> +
Send + 'a>>) {
+ let waker = futures::task::noop_waker();
+ let mut cx = futures::task::Context::from_waker(&waker);
+ let poll = fut.poll_unpin(&mut cx);
+
+ assert!(poll.is_pending());
+}
+
pub mod exec;
pub mod user_defined;
pub mod variable;