This is an automated email from the ASF dual-hosted git repository.

dheres pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new c22f575  Clean up spawned task on drop for `RepartitionExec`, 
`SortPreservingMergeExec`, `WindowAggExec` (#1112)
c22f575 is described below

commit c22f5753fc7e3dfc121adcc7d39a74bc2199d308
Author: Marco Neumann <[email protected]>
AuthorDate: Thu Oct 14 18:06:42 2021 +0200

    Clean up spawned task on drop for `RepartitionExec`, 
`SortPreservingMergeExec`, `WindowAggExec` (#1112)
    
    * allow `BlockingExec` to mock multiple partitions
    
    * Clean up spawned task on `SortPreservingMergeExec` drop
    
    Ref #1103.
    
    * Clean up spawned task on `WindowAggExec` drop
    
    Ref #1103.
    
    * isolate common test code
    
    * Clean up spawned task on `RepartitionExec` drop
    
    Ref #1103.
    
    * rename variables / fix copy+paste mistake
    
    * introduce `RepartitionExecState` struct
    
    * introduce `AbortOnDrop{Single,Many}` helpers
---
 datafusion/src/physical_plan/common.rs             | 50 +++++++++++-
 datafusion/src/physical_plan/repartition.rs        | 94 ++++++++++++++++++----
 datafusion/src/physical_plan/sort.rs               | 35 ++------
 .../src/physical_plan/sort_preserving_merge.rs     | 86 +++++++++++++++-----
 datafusion/src/physical_plan/windows/mod.rs        | 37 ++++++++-
 .../src/physical_plan/windows/window_agg_exec.rs   | 25 +++---
 datafusion/src/test/exec.rs                        | 28 ++++++-
 datafusion/src/test/mod.rs                         | 11 +++
 8 files changed, 285 insertions(+), 81 deletions(-)

diff --git a/datafusion/src/physical_plan/common.rs 
b/datafusion/src/physical_plan/common.rs
index 3be9e72..d6a37e0 100644
--- a/datafusion/src/physical_plan/common.rs
+++ b/datafusion/src/physical_plan/common.rs
@@ -26,7 +26,8 @@ use arrow::error::ArrowError;
 use arrow::error::Result as ArrowResult;
 use arrow::record_batch::RecordBatch;
 use futures::channel::mpsc;
-use futures::{SinkExt, Stream, StreamExt, TryStreamExt};
+use futures::{Future, SinkExt, Stream, StreamExt, TryStreamExt};
+use pin_project_lite::pin_project;
 use std::fs;
 use std::fs::metadata;
 use std::sync::Arc;
@@ -225,6 +226,53 @@ pub fn compute_record_batch_statistics(
     }
 }
 
+pin_project! {
+    /// Helper that aborts the given join handle on drop.
+    ///
+    /// Useful to kill background tasks when the consumer is dropped.
+    #[derive(Debug)]
+    pub struct AbortOnDropSingle<T>{
+        #[pin]
+        join_handle: JoinHandle<T>,
+    }
+
+    impl<T> PinnedDrop for AbortOnDropSingle<T> {
+        fn drop(this: Pin<&mut Self>) {
+            this.join_handle.abort();
+        }
+    }
+}
+
+impl<T> AbortOnDropSingle<T> {
+    /// Create new abort helper from join handle.
+    pub fn new(join_handle: JoinHandle<T>) -> Self {
+        Self { join_handle }
+    }
+}
+
+impl<T> Future for AbortOnDropSingle<T> {
+    type Output = std::result::Result<T, tokio::task::JoinError>;
+
+    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> 
Poll<Self::Output> {
+        let this = self.project();
+        this.join_handle.poll(cx)
+    }
+}
+
+/// Helper that aborts the given join handles on drop.
+///
+/// Useful to kill background tasks when the consumer is dropped.
+#[derive(Debug)]
+pub struct AbortOnDropMany<T>(pub Vec<JoinHandle<T>>);
+
+impl<T> Drop for AbortOnDropMany<T> {
+    fn drop(&mut self) {
+        for join_handle in &self.0 {
+            join_handle.abort();
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/datafusion/src/physical_plan/repartition.rs 
b/datafusion/src/physical_plan/repartition.rs
index 56de364..74696ea 100644
--- a/datafusion/src/physical_plan/repartition.rs
+++ b/datafusion/src/physical_plan/repartition.rs
@@ -31,6 +31,7 @@ use arrow::{array::Array, error::Result as ArrowResult};
 use arrow::{compute::take, datatypes::SchemaRef};
 use tokio_stream::wrappers::UnboundedReceiverStream;
 
+use super::common::{AbortOnDropMany, AbortOnDropSingle};
 use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
 use super::{RecordBatchStream, SendableRecordBatchStream};
 use async_trait::async_trait;
@@ -46,21 +47,30 @@ use tokio::task::JoinHandle;
 
 type MaybeBatch = Option<ArrowResult<RecordBatch>>;
 
+/// Inner state of [`RepartitionExec`].
+#[derive(Debug)]
+struct RepartitionExecState {
+    /// Channels for sending batches from input partitions to output 
partitions.
+    /// Key is the partition number.
+    channels:
+        HashMap<usize, (UnboundedSender<MaybeBatch>, 
UnboundedReceiver<MaybeBatch>)>,
+
+    /// Helper that ensures that the background jobs are killed once they are no 
longer needed.
+    abort_helper: Arc<AbortOnDropMany<()>>,
+}
+
 /// The repartition operator maps N input partitions to M output partitions 
based on a
 /// partitioning scheme. No guarantees are made about the order of the 
resulting partitions.
 #[derive(Debug)]
 pub struct RepartitionExec {
     /// Input execution plan
     input: Arc<dyn ExecutionPlan>,
+
     /// Partitioning scheme to use
     partitioning: Partitioning,
-    /// Channels for sending batches from input partitions to output 
partitions.
-    /// Key is the partition number
-    channels: Arc<
-        Mutex<
-            HashMap<usize, (UnboundedSender<MaybeBatch>, 
UnboundedReceiver<MaybeBatch>)>,
-        >,
-    >,
+
+    /// Inner state that is initialized when the first output stream is 
created.
+    state: Arc<Mutex<RepartitionExecState>>,
 
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
@@ -156,13 +166,13 @@ impl ExecutionPlan for RepartitionExec {
 
     async fn execute(&self, partition: usize) -> 
Result<SendableRecordBatchStream> {
         // lock mutexes
-        let mut channels = self.channels.lock().await;
+        let mut state = self.state.lock().await;
 
         let num_input_partitions = 
self.input.output_partitioning().partition_count();
         let num_output_partitions = self.partitioning.partition_count();
 
         // if this is the first partition to be invoked then we need to set up 
initial state
-        if channels.is_empty() {
+        if state.channels.is_empty() {
             // create one channel per *output* partition
             for partition in 0..num_output_partitions {
                 // Note that this operator uses unbounded channels to avoid 
deadlocks because
@@ -173,14 +183,16 @@ impl ExecutionPlan for RepartitionExec {
                 // for this would be to add spill-to-disk capabilities.
                 let (sender, receiver) =
                     
mpsc::unbounded_channel::<Option<ArrowResult<RecordBatch>>>();
-                channels.insert(partition, (sender, receiver));
+                state.channels.insert(partition, (sender, receiver));
             }
             // Use fixed random state
             let random = ahash::RandomState::with_seeds(0, 0, 0, 0);
 
             // launch one async task per *input* partition
+            let mut join_handles = Vec::with_capacity(num_input_partitions);
             for i in 0..num_input_partitions {
-                let txs: HashMap<_, _> = channels
+                let txs: HashMap<_, _> = state
+                    .channels
                     .iter()
                     .map(|(partition, (tx, _rx))| (*partition, tx.clone()))
                     .collect();
@@ -199,8 +211,14 @@ impl ExecutionPlan for RepartitionExec {
 
                 // In a separate task, wait for each input to be done
                 // (and pass along any errors, including panic!s)
-                tokio::spawn(Self::wait_for_task(input_task, txs));
+                let join_handle = tokio::spawn(Self::wait_for_task(
+                    AbortOnDropSingle::new(input_task),
+                    txs,
+                ));
+                join_handles.push(join_handle);
             }
+
+            state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
         }
 
         // now return stream for the specified *output* partition which will
@@ -209,7 +227,10 @@ impl ExecutionPlan for RepartitionExec {
             num_input_partitions,
             num_input_partitions_processed: 0,
             schema: self.input.schema(),
-            input: 
UnboundedReceiverStream::new(channels.remove(&partition).unwrap().1),
+            input: UnboundedReceiverStream::new(
+                state.channels.remove(&partition).unwrap().1,
+            ),
+            drop_helper: Arc::clone(&state.abort_helper),
         }))
     }
 
@@ -243,7 +264,10 @@ impl RepartitionExec {
         Ok(RepartitionExec {
             input,
             partitioning,
-            channels: Arc::new(Mutex::new(HashMap::new())),
+            state: Arc::new(Mutex::new(RepartitionExecState {
+                channels: HashMap::new(),
+                abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
+            })),
             metrics: ExecutionPlanMetricsSet::new(),
         })
     }
@@ -372,7 +396,7 @@ impl RepartitionExec {
     /// complete. Upon error, propagates the errors to all output tx
     /// channels.
     async fn wait_for_task(
-        input_task: JoinHandle<Result<()>>,
+        input_task: AbortOnDropSingle<Result<()>>,
         txs: HashMap<usize, UnboundedSender<Option<ArrowResult<RecordBatch>>>>,
     ) {
         // wait for completion, and propagate error
@@ -409,12 +433,19 @@ impl RepartitionExec {
 struct RepartitionStream {
     /// Number of input partitions that will be sending batches to this output 
channel
     num_input_partitions: usize,
+
     /// Number of input partitions that have finished sending batches to this 
output channel
     num_input_partitions_processed: usize,
+
     /// Schema
     schema: SchemaRef,
+
     /// channel containing the repartitioned batches
     input: UnboundedReceiverStream<Option<ArrowResult<RecordBatch>>>,
+
+    /// Handle to ensure background tasks are killed when no longer needed; held only for its `Drop` side effect.
+    #[allow(dead_code)]
+    drop_helper: Arc<AbortOnDropMany<()>>,
 }
 
 impl Stream for RepartitionStream {
@@ -454,8 +485,14 @@ mod tests {
     use super::*;
     use crate::{
         assert_batches_sorted_eq,
-        physical_plan::{expressions::col, memory::MemoryExec},
-        test::exec::{BarrierExec, ErrorExec, MockExec},
+        physical_plan::{collect, expressions::col, memory::MemoryExec},
+        test::{
+            assert_is_pending,
+            exec::{
+                assert_strong_count_converges_to_zero, BarrierExec, 
BlockingExec,
+                ErrorExec, MockExec,
+            },
+        },
     };
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
@@ -463,6 +500,7 @@ mod tests {
         array::{ArrayRef, StringArray, UInt32Array},
         error::ArrowError,
     };
+    use futures::FutureExt;
 
     #[tokio::test]
     async fn one_to_many_round_robin() -> Result<()> {
@@ -853,4 +891,26 @@ mod tests {
         let schema = batch1.schema();
         BarrierExec::new(vec![vec![batch1, batch2], vec![batch3, batch4]], 
schema)
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
2));
+        let refs = blocking_exec.refs();
+        let repartition_exec = Arc::new(RepartitionExec::try_new(
+            blocking_exec,
+            Partitioning::UnknownPartitioning(1),
+        )?);
+
+        let fut = collect(repartition_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/physical_plan/sort.rs 
b/datafusion/src/physical_plan/sort.rs
index 68a4258..499d1f7 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -17,6 +17,7 @@
 
 //! Defines the SORT plan
 
+use super::common::AbortOnDropSingle;
 use super::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
 };
@@ -40,7 +41,6 @@ use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
-use tokio::task::JoinHandle;
 
 /// Sort execution plan
 #[derive(Debug)]
@@ -229,13 +229,7 @@ pin_project! {
         output: 
futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
         finished: bool,
         schema: SchemaRef,
-        join_handle: JoinHandle<()>,
-    }
-
-    impl PinnedDrop for SortStream {
-        fn drop(this: Pin<&mut Self>) {
-            this.join_handle.abort();
-        }
+        drop_helper: AbortOnDropSingle<()>,
     }
 }
 
@@ -273,7 +267,7 @@ impl SortStream {
             output: rx,
             finished: false,
             schema,
-            join_handle,
+            drop_helper: AbortOnDropSingle::new(join_handle),
         }
     }
 }
@@ -315,14 +309,14 @@ impl RecordBatchStream for SortStream {
 
 #[cfg(test)]
 mod tests {
-    use std::sync::Weak;
-
     use super::*;
     use crate::datasource::object_store::local::LocalFileSystem;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::expressions::col;
     use crate::physical_plan::memory::MemoryExec;
     use crate::physical_plan::{collect, file_format::CsvExec};
+    use crate::test::assert_is_pending;
+    use crate::test::exec::assert_strong_count_converges_to_zero;
     use crate::test::{self, aggr_test_schema, exec::BlockingExec};
     use arrow::array::*;
     use arrow::datatypes::*;
@@ -497,7 +491,7 @@ mod tests {
         let schema =
             Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
 
-        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema)));
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
         let refs = blocking_exec.refs();
         let sort_exec = Arc::new(SortExec::try_new(
             vec![PhysicalSortExpr {
@@ -510,22 +504,9 @@ mod tests {
         let fut = collect(sort_exec);
         let mut fut = fut.boxed();
 
-        let waker = futures::task::noop_waker();
-        let mut cx = futures::task::Context::from_waker(&waker);
-        let poll = fut.poll_unpin(&mut cx);
-
-        assert!(poll.is_pending());
+        assert_is_pending(&mut fut);
         drop(fut);
-        tokio::time::timeout(std::time::Duration::from_secs(10), async {
-            loop {
-                if dbg!(Weak::strong_count(&refs)) == 0 {
-                    break;
-                }
-                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
-            }
-        })
-        .await
-        .unwrap();
+        assert_strong_count_converges_to_zero(refs).await;
 
         Ok(())
     }
diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs 
b/datafusion/src/physical_plan/sort_preserving_merge.rs
index f65facc..0b75f46 100644
--- a/datafusion/src/physical_plan/sort_preserving_merge.rs
+++ b/datafusion/src/physical_plan/sort_preserving_merge.rs
@@ -17,6 +17,7 @@
 
 //! Defines the sort preserving merge plan
 
+use super::common::AbortOnDropMany;
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use std::any::Any;
 use std::cmp::Ordering;
@@ -151,17 +152,19 @@ impl ExecutionPlan for SortPreservingMergeExec {
                 self.input.execute(0).await
             }
             _ => {
-                let streams = (0..input_partitions)
+                let (receivers, join_handles) = (0..input_partitions)
                     .into_iter()
                     .map(|part_i| {
                         let (sender, receiver) = mpsc::channel(1);
-                        spawn_execution(self.input.clone(), sender, part_i);
-                        receiver
+                        let join_handle =
+                            spawn_execution(self.input.clone(), sender, 
part_i);
+                        (receiver, join_handle)
                     })
-                    .collect();
+                    .unzip();
 
                 Ok(Box::pin(SortPreservingMergeStream::new(
-                    streams,
+                    receivers,
+                    AbortOnDropMany(join_handles),
                     self.schema(),
                     &self.expr,
                     self.target_batch_size,
@@ -338,23 +341,34 @@ struct RowIndex {
 struct SortPreservingMergeStream {
     /// The schema of the RecordBatches yielded by this stream
     schema: SchemaRef,
+
     /// The sorted input streams to merge together
-    streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+    receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+
+    /// Drop helper for tasks feeding the [`receivers`](Self::receivers)
+    drop_helper: AbortOnDropMany<()>,
+
     /// For each input stream maintain a dequeue of SortKeyCursor
     ///
     /// Exhausted cursors will be popped off the front once all
     /// their rows have been yielded to the output
     cursors: Vec<VecDeque<SortKeyCursor>>,
+
     /// The accumulated row indexes for the next record batch
     in_progress: Vec<RowIndex>,
+
     /// The physical expressions to sort by
     column_expressions: Vec<Arc<dyn PhysicalExpr>>,
+
     /// The sort options for each expression
     sort_options: Vec<SortOptions>,
+
     /// The desired RecordBatch size to yield
     target_batch_size: usize,
+
     /// used to record execution metrics
     baseline_metrics: BaselineMetrics,
+
     /// If the stream has encountered an error
     aborted: bool,
 
@@ -364,13 +378,14 @@ struct SortPreservingMergeStream {
 
 impl SortPreservingMergeStream {
     fn new(
-        streams: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+        receivers: Vec<mpsc::Receiver<ArrowResult<RecordBatch>>>,
+        drop_helper: AbortOnDropMany<()>,
         schema: SchemaRef,
         expressions: &[PhysicalSortExpr],
         target_batch_size: usize,
         baseline_metrics: BaselineMetrics,
     ) -> Self {
-        let cursors = (0..streams.len())
+        let cursors = (0..receivers.len())
             .into_iter()
             .map(|_| VecDeque::new())
             .collect();
@@ -378,7 +393,8 @@ impl SortPreservingMergeStream {
         Self {
             schema,
             cursors,
-            streams,
+            receivers,
+            drop_helper,
             column_expressions: expressions.iter().map(|x| 
x.expr.clone()).collect(),
             sort_options: expressions.iter().map(|x| x.options).collect(),
             target_batch_size,
@@ -404,7 +420,7 @@ impl SortPreservingMergeStream {
             }
         }
 
-        let stream = &mut self.streams[idx];
+        let stream = &mut self.receivers[idx];
         if stream.is_terminated() {
             return Poll::Ready(Ok(()));
         }
@@ -644,6 +660,7 @@ impl RecordBatchStream for SortPreservingMergeStream {
 mod tests {
     use crate::datasource::object_store::local::LocalFileSystem;
     use crate::physical_plan::metrics::MetricValue;
+    use crate::test::exec::{assert_strong_count_converges_to_zero, 
BlockingExec};
     use std::iter::FromIterator;
 
     use crate::arrow::array::{Int32Array, StringArray, 
TimestampNanosecondArray};
@@ -654,10 +671,11 @@ mod tests {
     use crate::physical_plan::memory::MemoryExec;
     use crate::physical_plan::sort::SortExec;
     use crate::physical_plan::{collect, common};
-    use crate::test;
+    use crate::test::{self, assert_is_pending};
 
     use super::*;
-    use futures::SinkExt;
+    use arrow::datatypes::{DataType, Field, Schema};
+    use futures::{FutureExt, SinkExt};
     use tokio_stream::StreamExt;
 
     #[tokio::test]
@@ -1172,28 +1190,30 @@ mod tests {
         let batches = sorted_partitioned_input(sort.clone(), &[5, 7, 3]).await;
 
         let partition_count = batches.output_partitioning().partition_count();
-        let mut tasks = Vec::with_capacity(partition_count);
-        let mut streams = Vec::with_capacity(partition_count);
+        let mut join_handles = Vec::with_capacity(partition_count);
+        let mut receivers = Vec::with_capacity(partition_count);
 
         for partition in 0..partition_count {
             let (mut sender, receiver) = mpsc::channel(1);
             let mut stream = batches.execute(partition).await.unwrap();
-            let task = tokio::spawn(async move {
+            let join_handle = tokio::spawn(async move {
                 while let Some(batch) = stream.next().await {
                     sender.send(batch).await.unwrap();
                     // This causes the MergeStream to wait for more input
                     
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                 }
             });
-            tasks.push(task);
-            streams.push(receiver);
+            join_handles.push(join_handle);
+            receivers.push(receiver);
         }
 
         let metrics = ExecutionPlanMetricsSet::new();
         let baseline_metrics = BaselineMetrics::new(&metrics, 0);
 
         let merge_stream = SortPreservingMergeStream::new(
-            streams,
+            receivers,
+            // Pass an empty helper: this test awaits the join handles itself below
+            AbortOnDropMany(vec![]),
             batches.schema(),
             sort.as_slice(),
             1024,
@@ -1203,8 +1223,8 @@ mod tests {
         let mut merged = 
common::collect(Box::pin(merge_stream)).await.unwrap();
 
         // Propagate any errors
-        for task in tasks {
-            task.await.unwrap();
+        for join_handle in join_handles {
+            join_handle.await.unwrap();
         }
 
         assert_eq!(merged.len(), 1);
@@ -1271,4 +1291,30 @@ mod tests {
         assert!(saw_start);
         assert!(saw_end);
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
2));
+        let refs = blocking_exec.refs();
+        let sort_preserving_merge_exec = Arc::new(SortPreservingMergeExec::new(
+            vec![PhysicalSortExpr {
+                expr: col("a", &schema)?,
+                options: SortOptions::default(),
+            }],
+            blocking_exec,
+            1,
+        ));
+
+        let fut = collect(sort_preserving_merge_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/physical_plan/windows/mod.rs 
b/datafusion/src/physical_plan/windows/mod.rs
index 3aa67cf..ef420b2 100644
--- a/datafusion/src/physical_plan/windows/mod.rs
+++ b/datafusion/src/physical_plan/windows/mod.rs
@@ -180,10 +180,12 @@ mod tests {
     use crate::physical_plan::expressions::col;
     use crate::physical_plan::file_format::CsvExec;
     use crate::physical_plan::{collect, Statistics};
-    use crate::test::{self, aggr_test_schema};
+    use crate::test::exec::{assert_strong_count_converges_to_zero, 
BlockingExec};
+    use crate::test::{self, aggr_test_schema, assert_is_pending};
     use arrow::array::*;
-    use arrow::datatypes::SchemaRef;
+    use arrow::datatypes::{DataType, Field, SchemaRef};
     use arrow::record_batch::RecordBatch;
+    use futures::FutureExt;
 
     fn create_test_schema(partitions: usize) -> Result<(Arc<CsvExec>, 
SchemaRef)> {
         let schema = test::aggr_test_schema();
@@ -264,4 +266,35 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
+        let refs = blocking_exec.refs();
+        let window_agg_exec = Arc::new(WindowAggExec::try_new(
+            vec![create_window_expr(
+                &WindowFunction::AggregateFunction(AggregateFunction::Count),
+                "count".to_owned(),
+                &[col("a", &schema)?],
+                &[],
+                &[],
+                Some(WindowFrame::default()),
+                schema.as_ref(),
+            )?],
+            blocking_exec,
+            schema,
+        )?);
+
+        let fut = collect(window_agg_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs 
b/datafusion/src/physical_plan/windows/window_agg_exec.rs
index 03307be..228b53f 100644
--- a/datafusion/src/physical_plan/windows/window_agg_exec.rs
+++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs
@@ -18,6 +18,7 @@
 //! Stream and channel implementations for window function expressions.
 
 use crate::error::{DataFusionError, Result};
+use crate::physical_plan::common::AbortOnDropSingle;
 use crate::physical_plan::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
 };
@@ -219,14 +220,15 @@ fn compute_window_aggregates(
 }
 
 pin_project! {
-  /// stream for window aggregation plan
-  pub struct WindowAggStream {
-      schema: SchemaRef,
-      #[pin]
-      output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
-      finished: bool,
-      baseline_metrics: BaselineMetrics,
-  }
+    /// stream for window aggregation plan
+    pub struct WindowAggStream {
+        schema: SchemaRef,
+        drop_helper: AbortOnDropSingle<()>,
+        #[pin]
+        output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
+        finished: bool,
+        baseline_metrics: BaselineMetrics,
+    }
 }
 
 impl WindowAggStream {
@@ -240,16 +242,19 @@ impl WindowAggStream {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema_clone = schema.clone();
         let elapsed_compute = baseline_metrics.elapsed_compute().clone();
-        tokio::spawn(async move {
+        let join_handle = tokio::spawn(async move {
             let schema = schema_clone.clone();
             let result =
                 WindowAggStream::process(input, window_expr, schema, 
elapsed_compute)
                     .await;
-            tx.send(result)
+
+            // failing here is OK, the receiver is gone and does not care 
about the result
+            tx.send(result).ok();
         });
 
         Self {
             schema,
+            drop_helper: AbortOnDropSingle::new(join_handle),
             output: rx,
             finished: false,
             baseline_metrics,
diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs
index 252168f..aca3b6e 100644
--- a/datafusion/src/test/exec.rs
+++ b/datafusion/src/test/exec.rs
@@ -482,15 +482,19 @@ pub struct BlockingExec {
     /// Schema that is mocked by this plan.
     schema: SchemaRef,
 
+    /// Number of output partitions.
+    n_partitions: usize,
+
     /// Ref-counting helper to check if the plan and the produced stream are 
still in memory.
     refs: Arc<()>,
 }
 
 impl BlockingExec {
-    /// Create new [`BlockingExec`] with a give schema.
-    pub fn new(schema: SchemaRef) -> Self {
+    /// Create new [`BlockingExec`] with a given schema and number of 
partitions.
+    pub fn new(schema: SchemaRef, n_partitions: usize) -> Self {
         Self {
             schema,
+            n_partitions,
             refs: Default::default(),
         }
     }
@@ -499,7 +503,7 @@ impl BlockingExec {
     ///
     /// Use [`Weak::strong_count`] to determine if the plan itself and its 
streams are dropped (should be 0 in that
     /// case). Note that tokio might take some time to cancel spawned tasks, 
so you need to wrap this check into a retry
-    /// loop.
+    /// loop. Use [`assert_strong_count_converges_to_zero`] to achieve this.
     pub fn refs(&self) -> Weak<()> {
         Arc::downgrade(&self.refs)
     }
@@ -521,7 +525,7 @@ impl ExecutionPlan for BlockingExec {
     }
 
     fn output_partitioning(&self) -> Partitioning {
-        Partitioning::UnknownPartitioning(1)
+        Partitioning::UnknownPartitioning(self.n_partitions)
     }
 
     fn with_new_children(
@@ -584,3 +588,19 @@ impl RecordBatchStream for BlockingStream {
         Arc::clone(&self.schema)
     }
 }
+
+/// Asserts that the strong count of the given [`Weak`] pointer converges to 
zero.
+///
+/// This might take a while; panics if the count has not reached zero within 10 seconds.
+pub async fn assert_strong_count_converges_to_zero<T>(refs: Weak<T>) {
+    tokio::time::timeout(std::time::Duration::from_secs(10), async {
+        loop {
+            if dbg!(Weak::strong_count(&refs)) == 0 {
+                break;
+            }
+            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+        }
+    })
+    .await
+    .unwrap();
+}
diff --git a/datafusion/src/test/mod.rs b/datafusion/src/test/mod.rs
index 917e7b1..c6eae35 100644
--- a/datafusion/src/test/mod.rs
+++ b/datafusion/src/test/mod.rs
@@ -28,9 +28,11 @@ use array::{
 use arrow::array::{self, Int32Array};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
+use futures::{Future, FutureExt};
 use std::fs::File;
 use std::io::prelude::*;
 use std::io::{BufReader, BufWriter};
+use std::pin::Pin;
 use std::sync::Arc;
 use tempfile::TempDir;
 
@@ -291,6 +293,15 @@ pub fn make_timestamps() -> RecordBatch {
     .unwrap()
 }
 
+/// Asserts that the given future is pending after being polled once with a no-op waker.
+pub fn assert_is_pending<'a, T>(fut: &mut Pin<Box<dyn Future<Output = T> + 
Send + 'a>>) {
+    let waker = futures::task::noop_waker();
+    let mut cx = futures::task::Context::from_waker(&waker);
+    let poll = fut.poll_unpin(&mut cx);
+
+    assert!(poll.is_pending());
+}
+
 pub mod exec;
 pub mod user_defined;
 pub mod variable;

Reply via email to