This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 268bdec  Clean up spawned task on `SortStream` drop (#1105)
268bdec is described below

commit 268bdec3a513c8e48bdcf995e253120dc9868fea
Author: Marco Neumann <[email protected]>
AuthorDate: Tue Oct 12 18:42:31 2021 +0200

    Clean up spawned task on `SortStream` drop (#1105)
    
    Ref #1103.
---
 datafusion/Cargo.toml                |   2 +-
 datafusion/src/physical_plan/sort.rs |  56 ++++++++++++++++-
 datafusion/src/test/exec.rs          | 114 ++++++++++++++++++++++++++++++++++-
 3 files changed, 168 insertions(+), 4 deletions(-)

diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index ea9ca21..ecc434a 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -58,7 +58,7 @@ num_cpus = "1.13.0"
 chrono = "0.4"
 async-trait = "0.1.41"
 futures = "0.3"
-pin-project-lite= "^0.2.0"
+pin-project-lite= "^0.2.7"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", 
"sync", "fs"] }
 tokio-stream = "0.1"
 log = "^0.4"
diff --git a/datafusion/src/physical_plan/sort.rs 
b/datafusion/src/physical_plan/sort.rs
index b732797..3032556 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -40,6 +40,7 @@ use std::any::Any;
 use std::pin::Pin;
 use std::sync::Arc;
 use std::task::{Context, Poll};
+use tokio::task::JoinHandle;
 
 /// Sort execution plan
 #[derive(Debug)]
@@ -228,6 +229,13 @@ pin_project! {
         output: 
futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
         finished: bool,
         schema: SchemaRef,
+        join_handle: JoinHandle<()>,
+    }
+
+    impl PinnedDrop for SortStream {
+        fn drop(this: Pin<&mut Self>) {
+            this.join_handle.abort();
+        }
     }
 }
 
@@ -239,7 +247,7 @@ impl SortStream {
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema = input.schema();
-        tokio::spawn(async move {
+        let join_handle = tokio::spawn(async move {
             let schema = input.schema();
             let sorted_batch = common::collect(input)
                 .await
@@ -257,13 +265,15 @@ impl SortStream {
                     Ok(result)
                 });
 
-            tx.send(sorted_batch)
+            // failing here is OK, the receiver is gone and does not care 
about the result
+            tx.send(sorted_batch).ok();
         });
 
         Self {
             output: rx,
             finished: false,
             schema,
+            join_handle,
         }
     }
 }
@@ -305,6 +315,8 @@ impl RecordBatchStream for SortStream {
 
 #[cfg(test)]
 mod tests {
+    use std::sync::Weak;
+
     use super::*;
     use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
     use crate::physical_plan::expressions::col;
@@ -314,8 +326,10 @@ mod tests {
         csv::{CsvExec, CsvReadOptions},
     };
     use crate::test;
+    use crate::test::exec::BlockingExec;
     use arrow::array::*;
     use arrow::datatypes::*;
+    use futures::FutureExt;
 
     #[tokio::test]
     async fn test_sort() -> Result<()> {
@@ -474,4 +488,42 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_drop_cancel() -> Result<()> {
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, 
true)]));
+
+        let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema)));
+        let refs = blocking_exec.refs();
+        let sort_exec = Arc::new(SortExec::try_new(
+            vec![PhysicalSortExpr {
+                expr: col("a", &schema)?,
+                options: SortOptions::default(),
+            }],
+            blocking_exec,
+        )?);
+
+        let fut = collect(sort_exec);
+        let mut fut = fut.boxed();
+
+        let waker = futures::task::noop_waker();
+        let mut cx = futures::task::Context::from_waker(&waker);
+        let poll = fut.poll_unpin(&mut cx);
+
+        assert!(poll.is_pending());
+        drop(fut);
+        tokio::time::timeout(std::time::Duration::from_secs(10), async {
+            loop {
+                if dbg!(Weak::strong_count(&refs)) == 0 {
+                    break;
+                }
+                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
+            }
+        })
+        .await
+        .unwrap();
+
+        Ok(())
+    }
 }
diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs
index 688cff8..252168f 100644
--- a/datafusion/src/test/exec.rs
+++ b/datafusion/src/test/exec.rs
@@ -20,7 +20,8 @@
 use async_trait::async_trait;
 use std::{
     any::Any,
-    sync::Arc,
+    pin::Pin,
+    sync::{Arc, Weak},
     task::{Context, Poll},
 };
 use tokio::sync::Barrier;
@@ -472,3 +473,114 @@ impl ExecutionPlan for StatisticsExec {
         }
     }
 }
+
+/// Execution plan that emits streams that block forever.
+///
+/// This is useful to test shutdown / cancelation behavior of certain 
execution plans.
+#[derive(Debug)]
+pub struct BlockingExec {
+    /// Schema that is mocked by this plan.
+    schema: SchemaRef,
+
+    /// Ref-counting helper to check if the plan and the produced stream are 
still in memory.
+    refs: Arc<()>,
+}
+
+impl BlockingExec {
+    /// Create new [`BlockingExec`] with a given schema.
+    pub fn new(schema: SchemaRef) -> Self {
+        Self {
+            schema,
+            refs: Default::default(),
+        }
+    }
+
+    /// Weak pointer that can be used for ref-counting this execution plan and 
its streams.
+    ///
+    /// Use [`Weak::strong_count`] to determine if the plan itself and its 
streams are dropped (should be 0 in that
+    /// case). Note that tokio might take some time to cancel spawned tasks, 
so you need to wrap this check into a retry
+    /// loop.
+    pub fn refs(&self) -> Weak<()> {
+        Arc::downgrade(&self.refs)
+    }
+}
+
+#[async_trait]
+impl ExecutionPlan for BlockingExec {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+
+    fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+        // this is a leaf node and has no children
+        vec![]
+    }
+
+    fn output_partitioning(&self) -> Partitioning {
+        Partitioning::UnknownPartitioning(1)
+    }
+
+    fn with_new_children(
+        &self,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        Err(DataFusionError::Internal(format!(
+            "Children cannot be replaced in {:?}",
+            self
+        )))
+    }
+
+    async fn execute(&self, _partition: usize) -> 
Result<SendableRecordBatchStream> {
+        Ok(Box::pin(BlockingStream {
+            schema: Arc::clone(&self.schema),
+            refs: Arc::clone(&self.refs),
+        }))
+    }
+
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default => {
+                write!(f, "BlockingExec",)
+            }
+        }
+    }
+
+    fn statistics(&self) -> Statistics {
+        unimplemented!()
+    }
+}
+
+/// A [`RecordBatchStream`] that is pending forever.
+#[derive(Debug)]
+pub struct BlockingStream {
+    /// Schema mocked by this stream.
+    schema: SchemaRef,
+
+    /// Ref-counting helper to check if the stream is still in memory.
+    refs: Arc<()>,
+}
+
+impl Stream for BlockingStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        Poll::Pending
+    }
+}
+
+impl RecordBatchStream for BlockingStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}

Reply via email to