This is an automated email from the ASF dual-hosted git repository.

houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 831e07d  Clean up spawned task on drop for `AnalyzeExec`, 
`CoalescePartitionsExec`, `HashAggregateExec` (#1121)
831e07d is described below

commit 831e07debc4f136f2e47e126b20e441f7606bd74
Author: Marco Neumann <[email protected]>
AuthorDate: Sat Oct 16 06:48:17 2021 +0200

    Clean up spawned task on drop for `AnalyzeExec`, `CoalescePartitionsExec`, 
`HashAggregateExec` (#1121)
    
    * Clean up spawned task on `HashAggregateExec` drop
    
    Ref #1103.
    
    * Clean up spawned task on `CoalescePartitionsExec` drop
    
    Ref #1103.
    
    * Clean up spawned task on `AnalyzeExec` drop
    
    As a side-effect, cancellation now works with all users of
    `RecordBatchReceiverStream` (e.g. `ParquetExec`) but there the effect
    should be slightly less important.
    
    Ref #1103.
---
 datafusion/src/physical_plan/analyze.rs            | 43 ++++++++++-
 .../src/physical_plan/coalesce_partitions.rs       | 40 +++++++++-
 .../src/physical_plan/file_format/parquet.rs       |  8 +-
 datafusion/src/physical_plan/hash_aggregate.rs     | 89 +++++++++++++++++++++-
 datafusion/src/physical_plan/stream.rs             | 13 +++-
 datafusion/src/test/exec.rs                        | 16 +++-
 6 files changed, 192 insertions(+), 17 deletions(-)

diff --git a/datafusion/src/physical_plan/analyze.rs 
b/datafusion/src/physical_plan/analyze.rs
index e68acc5..c9e316e 100644
--- a/datafusion/src/physical_plan/analyze.rs
+++ b/datafusion/src/physical_plan/analyze.rs
@@ -125,7 +125,7 @@ impl ExecutionPlan for AnalyzeExec {
 
         // Task reads batches the input and when complete produce a
         // RecordBatch with a report that is written to `tx` when done
-        tokio::task::spawn(async move {
+        let join_handle = tokio::task::spawn(async move {
             let start = Instant::now();
             let mut total_rows = 0;
 
@@ -194,7 +194,11 @@ impl ExecutionPlan for AnalyzeExec {
             tx.send(maybe_batch).await.ok();
         });
 
-        Ok(RecordBatchReceiverStream::create(&self.schema, rx))
+        Ok(RecordBatchReceiverStream::create(
+            &self.schema,
+            rx,
+            join_handle,
+        ))
     }
 
     fn fmt_as(
@@ -214,3 +218,38 @@ impl ExecutionPlan for AnalyzeExec {
         Statistics::default()
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use arrow::datatypes::{DataType, Field, Schema};
+    use futures::FutureExt;
+
+    use crate::{
+        physical_plan::collect,
+        test::{
+            assert_is_pending,
+            exec::{assert_strong_count_converges_to_zero, BlockingExec},
+        },
+    };
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
+        let refs = blocking_exec.refs();
+        let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec, 
schema));
+
+        let fut = collect(analyze_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
+}
diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs 
b/datafusion/src/physical_plan/coalesce_partitions.rs
index a106838..1fd18d2 100644
--- a/datafusion/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/src/physical_plan/coalesce_partitions.rs
@@ -30,6 +30,7 @@ use async_trait::async_trait;
 use arrow::record_batch::RecordBatch;
 use arrow::{datatypes::SchemaRef, error::Result as ArrowResult};
 
+use super::common::AbortOnDropMany;
 use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
 use super::{RecordBatchStream, Statistics};
 use crate::error::{DataFusionError, Result};
@@ -129,14 +130,20 @@ impl ExecutionPlan for CoalescePartitionsExec {
 
                 // spawn independent tasks whose resulting streams (of batches)
                 // are sent to the channel for consumption.
+                let mut join_handles = Vec::with_capacity(input_partitions);
                 for part_i in 0..input_partitions {
-                    spawn_execution(self.input.clone(), sender.clone(), 
part_i);
+                    join_handles.push(spawn_execution(
+                        self.input.clone(),
+                        sender.clone(),
+                        part_i,
+                    ));
                 }
 
                 Ok(Box::pin(MergeStream {
                     input: receiver,
                     schema: self.schema(),
                     baseline_metrics,
+                    drop_helper: AbortOnDropMany(join_handles),
                 }))
             }
         }
@@ -168,7 +175,8 @@ pin_project! {
         schema: SchemaRef,
         #[pin]
         input: mpsc::Receiver<ArrowResult<RecordBatch>>,
-        baseline_metrics: BaselineMetrics
+        baseline_metrics: BaselineMetrics,
+        drop_helper: AbortOnDropMany<()>,
     }
 }
 
@@ -194,11 +202,15 @@ impl RecordBatchStream for MergeStream {
 #[cfg(test)]
 mod tests {
 
+    use arrow::datatypes::{DataType, Field, Schema};
+    use futures::FutureExt;
+
     use super::*;
     use crate::datasource::object_store::local::LocalFileSystem;
-    use crate::physical_plan::common;
     use crate::physical_plan::file_format::CsvExec;
-    use crate::test;
+    use crate::physical_plan::{collect, common};
+    use crate::test::exec::{assert_strong_count_converges_to_zero, 
BlockingExec};
+    use crate::test::{self, assert_is_pending};
 
     #[tokio::test]
     async fn merge() -> Result<()> {
@@ -238,4 +250,24 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
2));
+        let refs = blocking_exec.refs();
+        let coaelesce_partitions_exec =
+            Arc::new(CoalescePartitionsExec::new(blocking_exec));
+
+        let fut = collect(coaelesce_partitions_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/physical_plan/file_format/parquet.rs 
b/datafusion/src/physical_plan/file_format/parquet.rs
index 77eed01..c011f33 100644
--- a/datafusion/src/physical_plan/file_format/parquet.rs
+++ b/datafusion/src/physical_plan/file_format/parquet.rs
@@ -312,7 +312,7 @@ impl ExecutionPlan for ParquetExec {
         let limit = self.limit;
         let object_store = Arc::clone(&self.object_store);
 
-        task::spawn_blocking(move || {
+        let join_handle = task::spawn_blocking(move || {
             if let Err(e) = read_partition(
                 object_store.as_ref(),
                 partition_index,
@@ -328,7 +328,11 @@ impl ExecutionPlan for ParquetExec {
             }
         });
 
-        Ok(RecordBatchReceiverStream::create(&self.schema, response_rx))
+        Ok(RecordBatchReceiverStream::create(
+            &self.schema,
+            response_rx,
+            join_handle,
+        ))
     }
 
     fn fmt_as(
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs 
b/datafusion/src/physical_plan/hash_aggregate.rs
index adeeb0b..33c6827 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -50,6 +50,7 @@ use pin_project_lite::pin_project;
 
 use async_trait::async_trait;
 
+use super::common::AbortOnDropSingle;
 use super::metrics::{
     self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
 };
@@ -339,6 +340,7 @@ pin_project! {
         #[pin]
         output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
         finished: bool,
+        drop_helper: AbortOnDropSingle<()>,
     }
 }
 
@@ -561,7 +563,8 @@ impl GroupedHashAggregateStream {
 
         let schema_clone = schema.clone();
         let elapsed_compute = baseline_metrics.elapsed_compute().clone();
-        tokio::spawn(async move {
+
+        let join_handle = tokio::spawn(async move {
             let result = compute_grouped_hash_aggregate(
                 mode,
                 schema_clone,
@@ -572,13 +575,16 @@ impl GroupedHashAggregateStream {
             )
             .await
             .record_output(&baseline_metrics);
-            tx.send(result)
+
+            // failing here is OK, the receiver is gone and does not care 
about the result
+            tx.send(result).ok();
         });
 
         Self {
             schema,
             output: rx,
             finished: false,
+            drop_helper: AbortOnDropSingle::new(join_handle),
         }
     }
 }
@@ -738,6 +744,7 @@ pin_project! {
         #[pin]
         output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
         finished: bool,
+        drop_helper: AbortOnDropSingle<()>,
     }
 }
 
@@ -789,7 +796,7 @@ impl HashAggregateStream {
 
         let schema_clone = schema.clone();
         let elapsed_compute = baseline_metrics.elapsed_compute().clone();
-        tokio::spawn(async move {
+        let join_handle = tokio::spawn(async move {
             let result = compute_hash_aggregate(
                 mode,
                 schema_clone,
@@ -800,13 +807,15 @@ impl HashAggregateStream {
             .await
             .record_output(&baseline_metrics);
 
-            tx.send(result)
+            // failing here is OK, the receiver is gone and does not care 
about the result
+            tx.send(result).ok();
         });
 
         Self {
             schema,
             output: rx,
             finished: false,
+            drop_helper: AbortOnDropSingle::new(join_handle),
         }
     }
 }
@@ -1005,9 +1014,12 @@ mod tests {
 
     use arrow::array::{Float64Array, UInt32Array};
     use arrow::datatypes::DataType;
+    use futures::FutureExt;
 
     use super::*;
     use crate::physical_plan::expressions::{col, Avg};
+    use crate::test::assert_is_pending;
+    use crate::test::exec::{assert_strong_count_converges_to_zero, 
BlockingExec};
     use crate::{assert_batches_sorted_eq, physical_plan::common};
 
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
@@ -1230,4 +1242,73 @@ mod tests {
 
         check_aggregates(input).await
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel_without_groups() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let groups = vec![];
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+            col("a", &schema)?,
+            "AVG(a)".to_string(),
+            DataType::Float64,
+        ))];
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
+        let refs = blocking_exec.refs();
+        let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            groups.clone(),
+            aggregates.clone(),
+            blocking_exec,
+            schema,
+        )?);
+
+        let fut = crate::physical_plan::collect(hash_aggregate_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_drop_cancel_with_groups() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Float32, true),
+            Field::new("b", DataType::Float32, true),
+        ]));
+
+        let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+            vec![(col("a", &schema)?, "a".to_string())];
+
+        let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+            col("b", &schema)?,
+            "AVG(b)".to_string(),
+            DataType::Float64,
+        ))];
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema), 
1));
+        let refs = blocking_exec.refs();
+        let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
+            AggregateMode::Partial,
+            groups.clone(),
+            aggregates.clone(),
+            blocking_exec,
+            schema,
+        )?);
+
+        let fut = crate::physical_plan::collect(hash_aggregate_exec);
+        let mut fut = fut.boxed();
+
+        assert_is_pending(&mut fut);
+        drop(fut);
+        assert_strong_count_converges_to_zero(refs).await;
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/physical_plan/stream.rs 
b/datafusion/src/physical_plan/stream.rs
index 0c29f87..67b7090 100644
--- a/datafusion/src/physical_plan/stream.rs
+++ b/datafusion/src/physical_plan/stream.rs
@@ -21,8 +21,10 @@ use arrow::{
     datatypes::SchemaRef, error::Result as ArrowResult, 
record_batch::RecordBatch,
 };
 use futures::{Stream, StreamExt};
+use tokio::task::JoinHandle;
 use tokio_stream::wrappers::ReceiverStream;
 
+use super::common::AbortOnDropSingle;
 use super::{RecordBatchStream, SendableRecordBatchStream};
 
 /// Adapter for a tokio [`ReceiverStream`] that implements the
@@ -30,7 +32,11 @@ use super::{RecordBatchStream, SendableRecordBatchStream};
 /// interface
 pub struct RecordBatchReceiverStream {
     schema: SchemaRef,
+
     inner: ReceiverStream<ArrowResult<RecordBatch>>,
+
+    #[allow(dead_code)]
+    drop_helper: AbortOnDropSingle<()>,
 }
 
 impl RecordBatchReceiverStream {
@@ -39,10 +45,15 @@ impl RecordBatchReceiverStream {
     pub fn create(
         schema: &SchemaRef,
         rx: tokio::sync::mpsc::Receiver<ArrowResult<RecordBatch>>,
+        join_handle: JoinHandle<()>,
     ) -> SendableRecordBatchStream {
         let schema = schema.clone();
         let inner = ReceiverStream::new(rx);
-        Box::pin(Self { schema, inner })
+        Box::pin(Self {
+            schema,
+            inner,
+            drop_helper: AbortOnDropSingle::new(join_handle),
+        })
     }
 }
 
diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs
index aca3b6e..fd10b9c 100644
--- a/datafusion/src/test/exec.rs
+++ b/datafusion/src/test/exec.rs
@@ -180,7 +180,7 @@ impl ExecutionPlan for MockExec {
         // task simply sends data in order but in a separate
         // thread (to ensure the batches are not available without the
         // DelayedStream yielding).
-        tokio::task::spawn(async move {
+        let join_handle = tokio::task::spawn(async move {
             for batch in data {
                 println!("Sending batch via delayed stream");
                 if let Err(e) = tx.send(batch).await {
@@ -190,7 +190,11 @@ impl ExecutionPlan for MockExec {
         });
 
         // returned stream simply reads off the rx stream
-        Ok(RecordBatchReceiverStream::create(&self.schema, rx))
+        Ok(RecordBatchReceiverStream::create(
+            &self.schema,
+            rx,
+            join_handle,
+        ))
     }
 
     fn fmt_as(
@@ -297,7 +301,7 @@ impl ExecutionPlan for BarrierExec {
         // task simply sends data in order after barrier is reached
         let data = self.data[partition].clone();
         let b = self.barrier.clone();
-        tokio::task::spawn(async move {
+        let join_handle = tokio::task::spawn(async move {
             println!("Partition {} waiting on barrier", partition);
             b.wait().await;
             for batch in data {
@@ -309,7 +313,11 @@ impl ExecutionPlan for BarrierExec {
         });
 
         // returned stream simply reads off the rx stream
-        Ok(RecordBatchReceiverStream::create(&self.schema, rx))
+        Ok(RecordBatchReceiverStream::create(
+            &self.schema,
+            rx,
+            join_handle,
+        ))
     }
 
     fn fmt_as(

Reply via email to