This is an automated email from the ASF dual-hosted git repository.
houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 831e07d Clean up spawned task on drop for `AnalyzeExec`,
`CoalescePartitionsExec`, `HashAggregateExec` (#1121)
831e07d is described below
commit 831e07debc4f136f2e47e126b20e441f7606bd74
Author: Marco Neumann <[email protected]>
AuthorDate: Sat Oct 16 06:48:17 2021 +0200
Clean up spawned task on drop for `AnalyzeExec`, `CoalescePartitionsExec`,
`HashAggregateExec` (#1121)
* Clean up spawned task on `HashAggregateExec` drop
Ref #1103.
* Clean up spawned task on `CoalescePartitionsExec` drop
Ref #1103.
* Clean up spawned task on `AnalyzeExec` drop
As a side-effect, cancellation now works with all users of
`RecordBatchReceiverStream` (e.g. `ParquetExec`), but there the effect
should be slightly less pronounced.
Ref #1103.
---
datafusion/src/physical_plan/analyze.rs | 43 ++++++++++-
.../src/physical_plan/coalesce_partitions.rs | 40 +++++++++-
.../src/physical_plan/file_format/parquet.rs | 8 +-
datafusion/src/physical_plan/hash_aggregate.rs | 89 +++++++++++++++++++++-
datafusion/src/physical_plan/stream.rs | 13 +++-
datafusion/src/test/exec.rs | 16 +++-
6 files changed, 192 insertions(+), 17 deletions(-)
diff --git a/datafusion/src/physical_plan/analyze.rs
b/datafusion/src/physical_plan/analyze.rs
index e68acc5..c9e316e 100644
--- a/datafusion/src/physical_plan/analyze.rs
+++ b/datafusion/src/physical_plan/analyze.rs
@@ -125,7 +125,7 @@ impl ExecutionPlan for AnalyzeExec {
// Task reads batches the input and when complete produce a
// RecordBatch with a report that is written to `tx` when done
- tokio::task::spawn(async move {
+ let join_handle = tokio::task::spawn(async move {
let start = Instant::now();
let mut total_rows = 0;
@@ -194,7 +194,11 @@ impl ExecutionPlan for AnalyzeExec {
tx.send(maybe_batch).await.ok();
});
- Ok(RecordBatchReceiverStream::create(&self.schema, rx))
+ Ok(RecordBatchReceiverStream::create(
+ &self.schema,
+ rx,
+ join_handle,
+ ))
}
fn fmt_as(
@@ -214,3 +218,38 @@ impl ExecutionPlan for AnalyzeExec {
Statistics::default()
}
}
+
+#[cfg(test)]
+mod tests {
+ use arrow::datatypes::{DataType, Field, Schema};
+ use futures::FutureExt;
+
+ use crate::{
+ physical_plan::collect,
+ test::{
+ assert_is_pending,
+ exec::{assert_strong_count_converges_to_zero, BlockingExec},
+ },
+ };
+
+ use super::*;
+
+ #[tokio::test]
+ async fn test_drop_cancel() -> Result<()> {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
+ let refs = blocking_exec.refs();
+ let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec,
schema));
+
+ let fut = collect(analyze_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
+}
diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs
b/datafusion/src/physical_plan/coalesce_partitions.rs
index a106838..1fd18d2 100644
--- a/datafusion/src/physical_plan/coalesce_partitions.rs
+++ b/datafusion/src/physical_plan/coalesce_partitions.rs
@@ -30,6 +30,7 @@ use async_trait::async_trait;
use arrow::record_batch::RecordBatch;
use arrow::{datatypes::SchemaRef, error::Result as ArrowResult};
+use super::common::AbortOnDropMany;
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use super::{RecordBatchStream, Statistics};
use crate::error::{DataFusionError, Result};
@@ -129,14 +130,20 @@ impl ExecutionPlan for CoalescePartitionsExec {
// spawn independent tasks whose resulting streams (of batches)
// are sent to the channel for consumption.
+ let mut join_handles = Vec::with_capacity(input_partitions);
for part_i in 0..input_partitions {
- spawn_execution(self.input.clone(), sender.clone(),
part_i);
+ join_handles.push(spawn_execution(
+ self.input.clone(),
+ sender.clone(),
+ part_i,
+ ));
}
Ok(Box::pin(MergeStream {
input: receiver,
schema: self.schema(),
baseline_metrics,
+ drop_helper: AbortOnDropMany(join_handles),
}))
}
}
@@ -168,7 +175,8 @@ pin_project! {
schema: SchemaRef,
#[pin]
input: mpsc::Receiver<ArrowResult<RecordBatch>>,
- baseline_metrics: BaselineMetrics
+ baseline_metrics: BaselineMetrics,
+ drop_helper: AbortOnDropMany<()>,
}
}
@@ -194,11 +202,15 @@ impl RecordBatchStream for MergeStream {
#[cfg(test)]
mod tests {
+ use arrow::datatypes::{DataType, Field, Schema};
+ use futures::FutureExt;
+
use super::*;
use crate::datasource::object_store::local::LocalFileSystem;
- use crate::physical_plan::common;
use crate::physical_plan::file_format::CsvExec;
- use crate::test;
+ use crate::physical_plan::{collect, common};
+ use crate::test::exec::{assert_strong_count_converges_to_zero,
BlockingExec};
+ use crate::test::{self, assert_is_pending};
#[tokio::test]
async fn merge() -> Result<()> {
@@ -238,4 +250,24 @@ mod tests {
Ok(())
}
+
+ #[tokio::test]
+ async fn test_drop_cancel() -> Result<()> {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
2));
+ let refs = blocking_exec.refs();
+ let coaelesce_partitions_exec =
+ Arc::new(CoalescePartitionsExec::new(blocking_exec));
+
+ let fut = collect(coaelesce_partitions_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/file_format/parquet.rs
b/datafusion/src/physical_plan/file_format/parquet.rs
index 77eed01..c011f33 100644
--- a/datafusion/src/physical_plan/file_format/parquet.rs
+++ b/datafusion/src/physical_plan/file_format/parquet.rs
@@ -312,7 +312,7 @@ impl ExecutionPlan for ParquetExec {
let limit = self.limit;
let object_store = Arc::clone(&self.object_store);
- task::spawn_blocking(move || {
+ let join_handle = task::spawn_blocking(move || {
if let Err(e) = read_partition(
object_store.as_ref(),
partition_index,
@@ -328,7 +328,11 @@ impl ExecutionPlan for ParquetExec {
}
});
- Ok(RecordBatchReceiverStream::create(&self.schema, response_rx))
+ Ok(RecordBatchReceiverStream::create(
+ &self.schema,
+ response_rx,
+ join_handle,
+ ))
}
fn fmt_as(
diff --git a/datafusion/src/physical_plan/hash_aggregate.rs
b/datafusion/src/physical_plan/hash_aggregate.rs
index adeeb0b..33c6827 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -50,6 +50,7 @@ use pin_project_lite::pin_project;
use async_trait::async_trait;
+use super::common::AbortOnDropSingle;
use super::metrics::{
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
@@ -339,6 +340,7 @@ pin_project! {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
+ drop_helper: AbortOnDropSingle<()>,
}
}
@@ -561,7 +563,8 @@ impl GroupedHashAggregateStream {
let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
- tokio::spawn(async move {
+
+ let join_handle = tokio::spawn(async move {
let result = compute_grouped_hash_aggregate(
mode,
schema_clone,
@@ -572,13 +575,16 @@ impl GroupedHashAggregateStream {
)
.await
.record_output(&baseline_metrics);
- tx.send(result)
+
+ // failing here is OK, the receiver is gone and does not care
about the result
+ tx.send(result).ok();
});
Self {
schema,
output: rx,
finished: false,
+ drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
@@ -738,6 +744,7 @@ pin_project! {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
+ drop_helper: AbortOnDropSingle<()>,
}
}
@@ -789,7 +796,7 @@ impl HashAggregateStream {
let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
- tokio::spawn(async move {
+ let join_handle = tokio::spawn(async move {
let result = compute_hash_aggregate(
mode,
schema_clone,
@@ -800,13 +807,15 @@ impl HashAggregateStream {
.await
.record_output(&baseline_metrics);
- tx.send(result)
+ // failing here is OK, the receiver is gone and does not care
about the result
+ tx.send(result).ok();
});
Self {
schema,
output: rx,
finished: false,
+ drop_helper: AbortOnDropSingle::new(join_handle),
}
}
}
@@ -1005,9 +1014,12 @@ mod tests {
use arrow::array::{Float64Array, UInt32Array};
use arrow::datatypes::DataType;
+ use futures::FutureExt;
use super::*;
use crate::physical_plan::expressions::{col, Avg};
+ use crate::test::assert_is_pending;
+ use crate::test::exec::{assert_strong_count_converges_to_zero,
BlockingExec};
use crate::{assert_batches_sorted_eq, physical_plan::common};
use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec;
@@ -1230,4 +1242,73 @@ mod tests {
check_aggregates(input).await
}
+
+ #[tokio::test]
+ async fn test_drop_cancel_without_groups() -> Result<()> {
+ let schema =
+ Arc::new(Schema::new(vec![Field::new("a", DataType::Float32,
true)]));
+
+ let groups = vec![];
+
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+ col("a", &schema)?,
+ "AVG(a)".to_string(),
+ DataType::Float64,
+ ))];
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
+ let refs = blocking_exec.refs();
+ let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
+ AggregateMode::Partial,
+ groups.clone(),
+ aggregates.clone(),
+ blocking_exec,
+ schema,
+ )?);
+
+ let fut = crate::physical_plan::collect(hash_aggregate_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_drop_cancel_with_groups() -> Result<()> {
+ let schema = Arc::new(Schema::new(vec![
+ Field::new("a", DataType::Float32, true),
+ Field::new("b", DataType::Float32, true),
+ ]));
+
+ let groups: Vec<(Arc<dyn PhysicalExpr>, String)> =
+ vec![(col("a", &schema)?, "a".to_string())];
+
+ let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![Arc::new(Avg::new(
+ col("b", &schema)?,
+ "AVG(b)".to_string(),
+ DataType::Float64,
+ ))];
+
+ let blocking_exec = Arc::new(BlockingExec::new(Arc::clone(&schema),
1));
+ let refs = blocking_exec.refs();
+ let hash_aggregate_exec = Arc::new(HashAggregateExec::try_new(
+ AggregateMode::Partial,
+ groups.clone(),
+ aggregates.clone(),
+ blocking_exec,
+ schema,
+ )?);
+
+ let fut = crate::physical_plan::collect(hash_aggregate_exec);
+ let mut fut = fut.boxed();
+
+ assert_is_pending(&mut fut);
+ drop(fut);
+ assert_strong_count_converges_to_zero(refs).await;
+
+ Ok(())
+ }
}
diff --git a/datafusion/src/physical_plan/stream.rs
b/datafusion/src/physical_plan/stream.rs
index 0c29f87..67b7090 100644
--- a/datafusion/src/physical_plan/stream.rs
+++ b/datafusion/src/physical_plan/stream.rs
@@ -21,8 +21,10 @@ use arrow::{
datatypes::SchemaRef, error::Result as ArrowResult,
record_batch::RecordBatch,
};
use futures::{Stream, StreamExt};
+use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
+use super::common::AbortOnDropSingle;
use super::{RecordBatchStream, SendableRecordBatchStream};
/// Adapter for a tokio [`ReceiverStream`] that implements the
@@ -30,7 +32,11 @@ use super::{RecordBatchStream, SendableRecordBatchStream};
/// interface
pub struct RecordBatchReceiverStream {
schema: SchemaRef,
+
inner: ReceiverStream<ArrowResult<RecordBatch>>,
+
+ #[allow(dead_code)]
+ drop_helper: AbortOnDropSingle<()>,
}
impl RecordBatchReceiverStream {
@@ -39,10 +45,15 @@ impl RecordBatchReceiverStream {
pub fn create(
schema: &SchemaRef,
rx: tokio::sync::mpsc::Receiver<ArrowResult<RecordBatch>>,
+ join_handle: JoinHandle<()>,
) -> SendableRecordBatchStream {
let schema = schema.clone();
let inner = ReceiverStream::new(rx);
- Box::pin(Self { schema, inner })
+ Box::pin(Self {
+ schema,
+ inner,
+ drop_helper: AbortOnDropSingle::new(join_handle),
+ })
}
}
diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs
index aca3b6e..fd10b9c 100644
--- a/datafusion/src/test/exec.rs
+++ b/datafusion/src/test/exec.rs
@@ -180,7 +180,7 @@ impl ExecutionPlan for MockExec {
// task simply sends data in order but in a separate
// thread (to ensure the batches are not available without the
// DelayedStream yielding).
- tokio::task::spawn(async move {
+ let join_handle = tokio::task::spawn(async move {
for batch in data {
println!("Sending batch via delayed stream");
if let Err(e) = tx.send(batch).await {
@@ -190,7 +190,11 @@ impl ExecutionPlan for MockExec {
});
// returned stream simply reads off the rx stream
- Ok(RecordBatchReceiverStream::create(&self.schema, rx))
+ Ok(RecordBatchReceiverStream::create(
+ &self.schema,
+ rx,
+ join_handle,
+ ))
}
fn fmt_as(
@@ -297,7 +301,7 @@ impl ExecutionPlan for BarrierExec {
// task simply sends data in order after barrier is reached
let data = self.data[partition].clone();
let b = self.barrier.clone();
- tokio::task::spawn(async move {
+ let join_handle = tokio::task::spawn(async move {
println!("Partition {} waiting on barrier", partition);
b.wait().await;
for batch in data {
@@ -309,7 +313,11 @@ impl ExecutionPlan for BarrierExec {
});
// returned stream simply reads off the rx stream
- Ok(RecordBatchReceiverStream::create(&self.schema, rx))
+ Ok(RecordBatchReceiverStream::create(
+ &self.schema,
+ rx,
+ join_handle,
+ ))
}
fn fmt_as(