This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 14264d2c39 fix: use `JoinSet` to make spawned tasks cancel-safe (#9318)
14264d2c39 is described below
commit 14264d2c3947e432f71bfe0af1a3dbafbb6ee686
Author: Artem Medvedev <[email protected]>
AuthorDate: Tue Feb 27 14:11:59 2024 +0100
fix: use `JoinSet` to make spawned tasks cancel-safe (#9318)
* fix: use `JoinSet` to make spawned tasks cancel-safe
* feat: drop `AbortOnDropSingle` and `AbortOnDropMany`
* style: doc lint
* fix: ordering of the tasks in `RepartitionExec`
* fix: replace spawn_blocking with JoinSet
* style: disallow spawn methods
* fixes: preserve ordering of tasks
* style: allow spawning in tests
* chore: exclude clippy.toml from rat
* chore: typo
* feat: introduce `SpawnedTask`
* revert outdated comment
* switch to SpawnedTask missed outdated part
* doc: improve reason for disallowed-method
---
clippy.toml | 4 ++
datafusion/core/src/dataframe/mod.rs | 1 +
.../core/src/datasource/file_format/arrow.rs | 2 +-
.../core/src/datasource/file_format/parquet.rs | 51 ++++++++-------
.../core/src/datasource/file_format/write/demux.rs | 12 ++--
.../datasource/file_format/write/orchestration.rs | 29 ++++-----
datafusion/core/src/datasource/stream.rs | 9 ++-
datafusion/core/src/execution/context/mod.rs | 1 +
datafusion/core/tests/fifo.rs | 2 +
.../fuzz_cases/sort_preserving_repartition_fuzz.rs | 1 +
datafusion/core/tests/fuzz_cases/window_fuzz.rs | 1 +
datafusion/physical-plan/src/common.rs | 73 ++++++++++------------
datafusion/physical-plan/src/lib.rs | 10 +--
datafusion/physical-plan/src/repartition/mod.rs | 46 +++++++-------
datafusion/physical-plan/src/sorts/sort.rs | 7 +--
datafusion/sqllogictest/bin/sqllogictests.rs | 1 +
dev/release/rat_exclude_files.txt | 3 +-
17 files changed, 129 insertions(+), 124 deletions(-)
diff --git a/clippy.toml b/clippy.toml
new file mode 100644
index 0000000000..c6c754e440
--- /dev/null
+++ b/clippy.toml
@@ -0,0 +1,4 @@
+disallowed-methods = [
+ { path = "tokio::task::spawn", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+ { path = "tokio::task::spawn_blocking", reason = "To provide cancel-safety, use `SpawnedTask::spawn` instead (https://github.com/apache/arrow-datafusion/issues/6513)" },
+]
diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs
index 3a60d57f66..c04247210d 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2172,6 +2172,7 @@ mod tests {
}
#[tokio::test]
+ #[allow(clippy::disallowed_methods)]
async fn sendable() {
let df = test_table().await.unwrap();
// dataframes should be sendable between threads/tasks
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs
index ead2db5a10..d5f07d11be 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,7 +295,7 @@ impl DataSink for ArrowFileSink {
}
}
- match demux_task.await {
+ match demux_task.join().await {
Ok(r) => r?,
Err(e) => {
if e.is_panic() {
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs
index 89ec81630c..7398501153 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -32,7 +32,7 @@ use std::fmt::Debug;
use std::sync::Arc;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver, Sender};
-use tokio::task::{JoinHandle, JoinSet};
+use tokio::task::JoinSet;
use crate::datasource::file_format::file_compression_type::FileCompressionType;
use crate::datasource::statistics::{create_max_min_accs, get_col_stats};
@@ -42,6 +42,7 @@ use bytes::{BufMut, BytesMut};
use datafusion_common::{exec_err, not_impl_err, DataFusionError, FileType};
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{PhysicalExpr, PhysicalSortRequirement};
+use datafusion_physical_plan::common::SpawnedTask;
use futures::{StreamExt, TryStreamExt};
use hashbrown::HashMap;
use object_store::path::Path;
@@ -728,7 +729,7 @@ impl DataSink for ParquetSink {
}
}
- match demux_task.await {
+ match demux_task.join().await {
Ok(r) => r?,
Err(e) => {
if e.is_panic() {
@@ -738,6 +739,7 @@ impl DataSink for ParquetSink {
}
}
}
+
Ok(row_count as u64)
}
}
@@ -754,8 +756,9 @@ async fn column_serializer_task(
Ok(writer)
}
-type ColumnJoinHandle = JoinHandle<Result<ArrowColumnWriter>>;
+type ColumnWriterTask = SpawnedTask<Result<ArrowColumnWriter>>;
type ColSender = Sender<ArrowLeafColumn>;
+
/// Spawns a parallel serialization task for each column
/// Returns join handles for each columns serialization task along with a send channel
/// to send arrow arrays to each serialization task.
@@ -763,23 +766,24 @@ fn spawn_column_parallel_row_group_writer(
schema: Arc<Schema>,
parquet_props: Arc<WriterProperties>,
max_buffer_size: usize,
-) -> Result<(Vec<ColumnJoinHandle>, Vec<ColSender>)> {
+) -> Result<(Vec<ColumnWriterTask>, Vec<ColSender>)> {
let schema_desc = arrow_to_parquet_schema(&schema)?;
let col_writers = get_column_writers(&schema_desc, &parquet_props,
&schema)?;
let num_columns = col_writers.len();
- let mut col_writer_handles = Vec::with_capacity(num_columns);
+ let mut col_writer_tasks = Vec::with_capacity(num_columns);
let mut col_array_channels = Vec::with_capacity(num_columns);
for writer in col_writers.into_iter() {
// Buffer size of this channel limits the number of arrays queued up for column level serialization
let (send_array, recieve_array) =
mpsc::channel::<ArrowLeafColumn>(max_buffer_size);
col_array_channels.push(send_array);
- col_writer_handles
- .push(tokio::spawn(column_serializer_task(recieve_array, writer)))
+
+ let task = SpawnedTask::spawn(column_serializer_task(recieve_array,
writer));
+ col_writer_tasks.push(task);
}
- Ok((col_writer_handles, col_array_channels))
+ Ok((col_writer_tasks, col_array_channels))
}
/// Settings related to writing parquet files in parallel
@@ -820,14 +824,14 @@ async fn send_arrays_to_col_writers(
/// Spawns a tokio task which joins the parallel column writer tasks,
/// and finalizes the row group
fn spawn_rg_join_and_finalize_task(
- column_writer_handles: Vec<JoinHandle<Result<ArrowColumnWriter>>>,
+ column_writer_tasks: Vec<ColumnWriterTask>,
rg_rows: usize,
-) -> JoinHandle<RBStreamSerializeResult> {
- tokio::spawn(async move {
- let num_cols = column_writer_handles.len();
+) -> SpawnedTask<RBStreamSerializeResult> {
+ SpawnedTask::spawn(async move {
+ let num_cols = column_writer_tasks.len();
let mut finalized_rg = Vec::with_capacity(num_cols);
- for handle in column_writer_handles.into_iter() {
- match handle.await {
+ for task in column_writer_tasks.into_iter() {
+ match task.join().await {
Ok(r) => {
let w = r?;
finalized_rg.push(w.close()?);
@@ -856,12 +860,12 @@ fn spawn_rg_join_and_finalize_task(
/// given by n_columns * num_row_groups.
fn spawn_parquet_parallel_serialization_task(
mut data: Receiver<RecordBatch>,
- serialize_tx: Sender<JoinHandle<RBStreamSerializeResult>>,
+ serialize_tx: Sender<SpawnedTask<RBStreamSerializeResult>>,
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
parallel_options: ParallelParquetWriterOptions,
-) -> JoinHandle<Result<(), DataFusionError>> {
- tokio::spawn(async move {
+) -> SpawnedTask<Result<(), DataFusionError>> {
+ SpawnedTask::spawn(async move {
let max_buffer_rb =
parallel_options.max_buffered_record_batches_per_stream;
let max_row_group_rows = writer_props.max_row_group_size();
let (mut column_writer_handles, mut col_array_channels) =
@@ -931,7 +935,7 @@ fn spawn_parquet_parallel_serialization_task(
/// Consume RowGroups serialized by other parallel tasks and concatenate them in
/// to the final parquet file, while flushing finalized bytes to an [ObjectStore]
async fn concatenate_parallel_row_groups(
- mut serialize_rx: Receiver<JoinHandle<RBStreamSerializeResult>>,
+ mut serialize_rx: Receiver<SpawnedTask<RBStreamSerializeResult>>,
schema: Arc<Schema>,
writer_props: Arc<WriterProperties>,
mut object_store_writer: AbortableWrite<Box<dyn AsyncWrite + Send +
Unpin>>,
@@ -947,9 +951,8 @@ async fn concatenate_parallel_row_groups(
let mut row_count = 0;
- while let Some(handle) = serialize_rx.recv().await {
- let join_result = handle.await;
- match join_result {
+ while let Some(task) = serialize_rx.recv().await {
+ match task.join().await {
Ok(result) => {
let mut rg_out = parquet_writer.next_row_group()?;
let (serialized_columns, cnt) = result?;
@@ -999,7 +1002,7 @@ async fn output_single_parquet_file_parallelized(
let max_rowgroups = parallel_options.max_parallel_row_groups;
// Buffer size of this channel limits maximum number of RowGroups being worked on in parallel
let (serialize_tx, serialize_rx) =
- mpsc::channel::<JoinHandle<RBStreamSerializeResult>>(max_rowgroups);
+ mpsc::channel::<SpawnedTask<RBStreamSerializeResult>>(max_rowgroups);
let arc_props = Arc::new(parquet_props.clone());
let launch_serialization_task = spawn_parquet_parallel_serialization_task(
@@ -1017,7 +1020,7 @@ async fn output_single_parquet_file_parallelized(
)
.await?;
- match launch_serialization_task.await {
+ match launch_serialization_task.join().await {
Ok(Ok(_)) => (),
Ok(Err(e)) => return Err(e),
Err(e) => {
@@ -1027,7 +1030,7 @@ async fn output_single_parquet_file_parallelized(
unreachable!()
}
}
- };
+ }
Ok(row_count)
}
diff --git a/datafusion/core/src/datasource/file_format/write/demux.rs b/datafusion/core/src/datasource/file_format/write/demux.rs
index 8bccf3d71c..d70b4811da 100644
--- a/datafusion/core/src/datasource/file_format/write/demux.rs
+++ b/datafusion/core/src/datasource/file_format/write/demux.rs
@@ -41,8 +41,8 @@ use object_store::path::Path;
use rand::distributions::DistString;
+use datafusion_physical_plan::common::SpawnedTask;
use tokio::sync::mpsc::{self, Receiver, Sender, UnboundedReceiver,
UnboundedSender};
-use tokio::task::JoinHandle;
type RecordBatchReceiver = Receiver<RecordBatch>;
type DemuxedStreamReceiver = UnboundedReceiver<(Path, RecordBatchReceiver)>;
@@ -76,15 +76,15 @@ pub(crate) fn start_demuxer_task(
partition_by: Option<Vec<(String, DataType)>>,
base_output_path: ListingTableUrl,
file_extension: String,
-) -> (JoinHandle<Result<()>>, DemuxedStreamReceiver) {
- let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
+) -> (SpawnedTask<Result<()>>, DemuxedStreamReceiver) {
+ let (tx, rx) = mpsc::unbounded_channel();
let context = context.clone();
let single_file_output = !base_output_path.is_collection();
- let task: JoinHandle<std::result::Result<(), DataFusionError>> = match
partition_by {
+ let task = match partition_by {
Some(parts) => {
// There could be an arbitrarily large number of parallel hive style partitions being written to, so we cannot
// bound this channel without risking a deadlock.
- tokio::spawn(async move {
+ SpawnedTask::spawn(async move {
hive_style_partitions_demuxer(
tx,
input,
@@ -96,7 +96,7 @@ pub(crate) fn start_demuxer_task(
.await
})
}
- None => tokio::spawn(async move {
+ None => SpawnedTask::spawn(async move {
row_count_demuxer(
tx,
input,
diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index 1a3042cbc0..05406d3751 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -33,10 +33,11 @@ use datafusion_common::{internal_datafusion_err,
internal_err, DataFusionError};
use datafusion_execution::TaskContext;
use bytes::Bytes;
+use datafusion_physical_plan::common::SpawnedTask;
+use futures::try_join;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};
-use tokio::task::{JoinHandle, JoinSet};
-use tokio::try_join;
+use tokio::task::JoinSet;
type WriterType = AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>;
type SerializerType = Arc<dyn BatchSerializer>;
@@ -51,14 +52,14 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
) -> std::result::Result<(WriterType, u64), (WriterType, DataFusionError)> {
let (tx, mut rx) =
- mpsc::channel::<JoinHandle<Result<(usize, Bytes),
DataFusionError>>>(100);
- let serialize_task = tokio::spawn(async move {
+ mpsc::channel::<SpawnedTask<Result<(usize, Bytes),
DataFusionError>>>(100);
+ let serialize_task = SpawnedTask::spawn(async move {
// Some serializers (like CSV) handle the first batch differently than
// subsequent batches, so we track that here.
let mut initial = true;
while let Some(batch) = data_rx.recv().await {
let serializer_clone = serializer.clone();
- let handle = tokio::spawn(async move {
+ let task = SpawnedTask::spawn(async move {
let num_rows = batch.num_rows();
let bytes = serializer_clone.serialize(batch, initial)?;
Ok((num_rows, bytes))
@@ -66,7 +67,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
if initial {
initial = false;
}
- tx.send(handle).await.map_err(|_| {
+ tx.send(task).await.map_err(|_| {
internal_datafusion_err!("Unknown error writing to object store")
})?;
}
@@ -74,8 +75,8 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
});
let mut row_count = 0;
- while let Some(handle) = rx.recv().await {
- match handle.await {
+ while let Some(task) = rx.recv().await {
+ match task.join().await {
Ok(Ok((cnt, bytes))) => {
match writer.write_all(&bytes).await {
Ok(_) => (),
@@ -106,7 +107,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
}
}
- match serialize_task.await {
+ match serialize_task.join().await {
Ok(Ok(_)) => (),
Ok(Err(e)) => return Err((writer, e)),
Err(_) => {
@@ -115,7 +116,7 @@ pub(crate) async fn serialize_rb_stream_to_object_store(
internal_datafusion_err!("Unknown error writing to object store"),
))
}
- };
+ }
Ok((writer, row_count as u64))
}
@@ -241,9 +242,9 @@ pub(crate) async fn stateless_multipart_put(
.execution
.max_buffered_batches_per_output_file;
- let (tx_file_bundle, rx_file_bundle) =
tokio::sync::mpsc::channel(rb_buffer_size / 2);
+ let (tx_file_bundle, rx_file_bundle) = mpsc::channel(rb_buffer_size / 2);
let (tx_row_cnt, rx_row_cnt) = tokio::sync::oneshot::channel();
- let write_coordinater_task = tokio::spawn(async move {
+ let write_coordinator_task = SpawnedTask::spawn(async move {
stateless_serialize_and_write_files(rx_file_bundle, tx_row_cnt).await
});
while let Some((location, rb_stream)) = file_stream_rx.recv().await {
@@ -260,10 +261,10 @@ pub(crate) async fn stateless_multipart_put(
})?;
}
- // Signal to the write coordinater that no more files are coming
+ // Signal to the write coordinator that no more files are coming
drop(tx_file_bundle);
- match try_join!(write_coordinater_task, demux_task) {
+ match try_join!(write_coordinator_task.join(), demux_task.join()) {
Ok((r1, r2)) => {
r1?;
r2?;
diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs
index 830cd7a07e..6dc59e4a5c 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -29,12 +29,11 @@ use arrow_array::{RecordBatch, RecordBatchReader,
RecordBatchWriter};
use arrow_schema::SchemaRef;
use async_trait::async_trait;
use futures::StreamExt;
-use tokio::task::spawn_blocking;
use datafusion_common::{plan_err, Constraints, DataFusionError, Result};
use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::{CreateExternalTable, Expr, TableType};
-use datafusion_physical_plan::common::AbortOnDropSingle;
+use datafusion_physical_plan::common::SpawnedTask;
use datafusion_physical_plan::insert::{DataSink, FileSinkExec};
use datafusion_physical_plan::metrics::MetricsSet;
use datafusion_physical_plan::stream::RecordBatchReceiverStreamBuilder;
@@ -344,7 +343,7 @@ impl DataSink for StreamWrite {
let config = self.0.clone();
let (sender, mut receiver) =
tokio::sync::mpsc::channel::<RecordBatch>(2);
// Note: FIFO Files support poll so this could use AsyncFd
- let write = AbortOnDropSingle::new(spawn_blocking(move || {
+ let write_task = SpawnedTask::spawn_blocking(move || {
let mut count = 0_u64;
let mut writer = config.writer()?;
while let Some(batch) = receiver.blocking_recv() {
@@ -352,7 +351,7 @@ impl DataSink for StreamWrite {
writer.write(&batch)?;
}
Ok(count)
- }));
+ });
while let Some(b) = data.next().await.transpose()? {
if sender.send(b).await.is_err() {
@@ -360,6 +359,6 @@ impl DataSink for StreamWrite {
}
}
drop(sender);
- write.await.unwrap()
+ write_task.join().await.unwrap()
}
}
diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs
index ffc4a4f717..453a00a1a5 100644
--- a/datafusion/core/src/execution/context/mod.rs
+++ b/datafusion/core/src/execution/context/mod.rs
@@ -2288,6 +2288,7 @@ mod tests {
}
#[tokio::test]
+ #[allow(clippy::disallowed_methods)]
async fn send_context_to_threads() -> Result<()> {
// ensure SessionContexts can be used in a multi-threaded
// environment. Usecase is for concurrent planing.
diff --git a/datafusion/core/tests/fifo.rs b/datafusion/core/tests/fifo.rs
index 93c7f73680..c9ad95a3a0 100644
--- a/datafusion/core/tests/fifo.rs
+++ b/datafusion/core/tests/fifo.rs
@@ -103,6 +103,7 @@ mod unix_test {
let broken_pipe_timeout = Duration::from_secs(10);
let sa = file_path.clone();
// Spawn a new thread to write to the FIFO file
+ #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
spawn_blocking(move || {
let file = OpenOptions::new().write(true).open(sa).unwrap();
// Reference time to use when deciding to fail the test
@@ -357,6 +358,7 @@ mod unix_test {
(sink_fifo_path.clone(), sink_fifo_path.display());
// Spawn a new thread to read sink EXTERNAL TABLE.
+ #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
tasks.push(spawn_blocking(move || {
let file = File::open(sink_fifo_path_thread).unwrap();
let schema = Arc::new(Schema::new(vec![
diff --git a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
index df6499e9b1..6c9c3359eb 100644
--- a/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_preserving_repartition_fuzz.rs
@@ -302,6 +302,7 @@ mod sp_repartition_fuzz_tests {
let mut handles = Vec::new();
for seed in seed_start..seed_end {
+ #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
let job =
tokio::spawn(run_sort_preserving_repartition_test(
make_staggered_batches::<true>(n_row, n_distinct, seed
as u64),
is_first_roundrobin,
diff --git a/datafusion/core/tests/fuzz_cases/window_fuzz.rs b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
index 609d26c9c2..1cab4d5c2f 100644
--- a/datafusion/core/tests/fuzz_cases/window_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/window_fuzz.rs
@@ -123,6 +123,7 @@ async fn window_bounded_window_random_comparison() ->
Result<()> {
for i in 0..n {
let idx = i % test_cases.len();
let (pb_cols, ob_cols, search_mode) = test_cases[idx].clone();
+ #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
let job = tokio::spawn(run_window_test(
make_staggered_batches::<true>(1000, n_distinct, i as u64),
i as u64,
diff --git a/datafusion/physical-plan/src/common.rs b/datafusion/physical-plan/src/common.rs
index e83dc2525b..5172bc9b2a 100644
--- a/datafusion/physical-plan/src/common.rs
+++ b/datafusion/physical-plan/src/common.rs
@@ -21,7 +21,6 @@ use std::fs;
use std::fs::{metadata, File};
use std::path::{Path, PathBuf};
use std::sync::Arc;
-use std::task::{Context, Poll};
use super::SendableRecordBatchStream;
use crate::stream::RecordBatchReceiverStream;
@@ -39,8 +38,7 @@ use datafusion_physical_expr::{PhysicalExpr,
PhysicalSortExpr};
use futures::{Future, StreamExt, TryStreamExt};
use parking_lot::Mutex;
-use pin_project_lite::pin_project;
-use tokio::task::JoinHandle;
+use tokio::task::{JoinError, JoinSet};
/// [`MemoryReservation`] used across query execution streams
pub(crate) type SharedMemoryReservation = Arc<Mutex<MemoryReservation>>;
@@ -174,50 +172,43 @@ pub fn compute_record_batch_statistics(
}
}
-pin_project! {
- /// Helper that aborts the given join handle on drop.
- ///
- /// Useful to kill background tasks when the consumer is dropped.
- #[derive(Debug)]
- pub struct AbortOnDropSingle<T>{
- #[pin]
- join_handle: JoinHandle<T>,
- }
-
- impl<T> PinnedDrop for AbortOnDropSingle<T> {
- fn drop(this: Pin<&mut Self>) {
- this.join_handle.abort();
- }
- }
+/// Helper that provides a simple API to spawn a single task and join it.
+/// Provides guarantees of aborting on `Drop` to keep it cancel-safe.
+///
+/// Technically, it's just a wrapper of `JoinSet` (with size=1).
+#[derive(Debug)]
+pub struct SpawnedTask<R> {
+ inner: JoinSet<R>,
}
-impl<T> AbortOnDropSingle<T> {
- /// Create new abort helper from join handle.
- pub fn new(join_handle: JoinHandle<T>) -> Self {
- Self { join_handle }
+impl<R: 'static> SpawnedTask<R> {
+ pub fn spawn<T>(task: T) -> Self
+ where
+ T: Future<Output = R>,
+ T: Send + 'static,
+ R: Send,
+ {
+ let mut inner = JoinSet::new();
+ inner.spawn(task);
+ Self { inner }
}
-}
-impl<T> Future for AbortOnDropSingle<T> {
- type Output = Result<T, tokio::task::JoinError>;
-
- fn poll(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) ->
Poll<Self::Output> {
- let this = self.project();
- this.join_handle.poll(cx)
+ pub fn spawn_blocking<T>(task: T) -> Self
+ where
+ T: FnOnce() -> R,
+ T: Send + 'static,
+ R: Send,
+ {
+ let mut inner = JoinSet::new();
+ inner.spawn_blocking(task);
+ Self { inner }
}
-}
-
-/// Helper that aborts the given join handles on drop.
-///
-/// Useful to kill background tasks when the consumer is dropped.
-#[derive(Debug)]
-pub struct AbortOnDropMany<T>(pub Vec<JoinHandle<T>>);
-impl<T> Drop for AbortOnDropMany<T> {
- fn drop(&mut self) {
- for join_handle in &self.0 {
- join_handle.abort();
- }
+ pub async fn join(mut self) -> Result<R, JoinError> {
+ self.inner
+ .join_next()
+ .await
+ .expect("`SpawnedTask` instance always contains exactly 1 task")
}
}
diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs
index 1c4a6ac0ec..562e42a7da 100644
--- a/datafusion/physical-plan/src/lib.rs
+++ b/datafusion/physical-plan/src/lib.rs
@@ -298,14 +298,14 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync {
/// "abort" such tasks, they may continue to consume resources even after
/// the plan is dropped, generating intermediate results that are never
/// used.
+ /// Thus, [`spawn`] is disallowed, and instead use [`SpawnedTask`].
///
- /// See [`AbortOnDropSingle`], [`AbortOnDropMany`] and
- /// [`RecordBatchReceiverStreamBuilder`] for structures to help ensure all
- /// background tasks are cancelled.
+ /// For more details see [`SpawnedTask`], [`JoinSet`] and [`RecordBatchReceiverStreamBuilder`]
+ /// for structures to help ensure all background tasks are cancelled.
///
/// [`spawn`]: tokio::task::spawn
- /// [`AbortOnDropSingle`]: crate::common::AbortOnDropSingle
- /// [`AbortOnDropMany`]: crate::common::AbortOnDropMany
+ /// [`JoinSet`]: tokio::task::JoinSet
+ /// [`SpawnedTask`]: crate::common::SpawnedTask
/// [`RecordBatchReceiverStreamBuilder`]: crate::stream::RecordBatchReceiverStreamBuilder
///
/// # Implementation Examples
diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 07693f747f..a66a929796 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -32,21 +32,20 @@ use futures::{FutureExt, StreamExt};
use hashbrown::HashMap;
use log::trace;
use parking_lot::Mutex;
-use tokio::task::JoinHandle;
use datafusion_common::{arrow_datafusion_err, not_impl_err, DataFusionError,
Result};
use datafusion_execution::memory_pool::MemoryConsumer;
use datafusion_execution::TaskContext;
use datafusion_physical_expr::{EquivalenceProperties, PhysicalExpr};
-use crate::common::transpose;
+use crate::common::{transpose, SpawnedTask};
use crate::hash_utils::create_hashes;
use crate::metrics::BaselineMetrics;
use crate::repartition::distributor_channels::{channels,
partition_aware_channels};
use crate::sorts::streaming_merge;
use crate::{DisplayFormatType, ExecutionPlan, Partitioning, Statistics};
-use super::common::{AbortOnDropMany, AbortOnDropSingle,
SharedMemoryReservation};
+use super::common::SharedMemoryReservation;
use super::expressions::PhysicalSortExpr;
use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::{DisplayAs, RecordBatchStream, SendableRecordBatchStream};
@@ -74,7 +73,7 @@ struct RepartitionExecState {
>,
/// Helper that ensures that that background job is killed once it is no longer needed.
- abort_helper: Arc<AbortOnDropMany<()>>,
+ abort_helper: Arc<Vec<SpawnedTask<()>>>,
}
/// A utility that can be used to partition batches based on [`Partitioning`]
@@ -522,7 +521,7 @@ impl ExecutionPlan for RepartitionExec {
}
// launch one async task per *input* partition
- let mut join_handles = Vec::with_capacity(num_input_partitions);
+ let mut spawned_tasks = Vec::with_capacity(num_input_partitions);
for i in 0..num_input_partitions {
let txs: HashMap<_, _> = state
.channels
@@ -534,28 +533,27 @@ impl ExecutionPlan for RepartitionExec {
let r_metrics = RepartitionMetrics::new(i, partition,
&self.metrics);
- let input_task: JoinHandle<Result<()>> =
- tokio::spawn(Self::pull_from_input(
- self.input.clone(),
- i,
- txs.clone(),
- self.partitioning.clone(),
- r_metrics,
- context.clone(),
- ));
+ let input_task = SpawnedTask::spawn(Self::pull_from_input(
+ self.input.clone(),
+ i,
+ txs.clone(),
+ self.partitioning.clone(),
+ r_metrics,
+ context.clone(),
+ ));
// In a separate task, wait for each input to be done
// (and pass along any errors, including panic!s)
- let join_handle = tokio::spawn(Self::wait_for_task(
- AbortOnDropSingle::new(input_task),
+ let wait_for_task = SpawnedTask::spawn(Self::wait_for_task(
+ input_task,
txs.into_iter()
.map(|(partition, (tx, _reservation))| (partition, tx))
.collect(),
));
- join_handles.push(join_handle);
+ spawned_tasks.push(wait_for_task);
}
- state.abort_helper = Arc::new(AbortOnDropMany(join_handles))
+ state.abort_helper = Arc::new(spawned_tasks)
}
trace!(
@@ -638,7 +636,7 @@ impl RepartitionExec {
partitioning,
state: Arc::new(Mutex::new(RepartitionExecState {
channels: HashMap::new(),
- abort_helper: Arc::new(AbortOnDropMany::<()>(vec![])),
+ abort_helper: Arc::new(Vec::new()),
})),
metrics: ExecutionPlanMetricsSet::new(),
preserve_order: false,
@@ -759,12 +757,13 @@ impl RepartitionExec {
/// complete. Upon error, propagates the errors to all output tx
/// channels.
async fn wait_for_task(
- input_task: AbortOnDropSingle<Result<()>>,
+ input_task: SpawnedTask<Result<()>>,
txs: HashMap<usize, DistributionSender<MaybeBatch>>,
) {
// wait for completion, and propagate error
// note we ignore errors on send (.ok) as that means the receiver has already shutdown.
- match input_task.await {
+
+ match input_task.join().await {
// Error in joining task
Err(e) => {
let e = Arc::new(e);
@@ -813,7 +812,7 @@ struct RepartitionStream {
/// Handle to ensure background tasks are killed when no longer needed.
#[allow(dead_code)]
- drop_helper: Arc<AbortOnDropMany<()>>,
+ drop_helper: Arc<Vec<SpawnedTask<()>>>,
/// Memory reservation.
reservation: SharedMemoryReservation,
@@ -877,7 +876,7 @@ struct PerPartitionStream {
/// Handle to ensure background tasks are killed when no longer needed.
#[allow(dead_code)]
- drop_helper: Arc<AbortOnDropMany<()>>,
+ drop_helper: Arc<Vec<SpawnedTask<()>>>,
/// Memory reservation.
reservation: SharedMemoryReservation,
@@ -1056,6 +1055,7 @@ mod tests {
}
#[tokio::test]
+ #[allow(clippy::disallowed_methods)]
async fn many_to_many_round_robin_within_tokio_task() -> Result<()> {
let join_handle: JoinHandle<Result<Vec<Vec<RecordBatch>>>> =
tokio::spawn(async move {
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 2d8237011f..84bf3ec415 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -27,7 +27,7 @@ use std::io::BufReader;
use std::path::{Path, PathBuf};
use std::sync::Arc;
-use crate::common::{spawn_buffered, IPCWriter};
+use crate::common::{spawn_buffered, IPCWriter, SpawnedTask};
use crate::expressions::PhysicalSortExpr;
use crate::metrics::{
BaselineMetrics, Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -56,7 +56,6 @@ use datafusion_physical_expr::EquivalenceProperties;
use futures::{StreamExt, TryStreamExt};
use log::{debug, error, trace};
use tokio::sync::mpsc::Sender;
-use tokio::task;
struct ExternalSorterMetrics {
/// metrics
@@ -604,8 +603,8 @@ async fn spill_sorted_batches(
schema: SchemaRef,
) -> Result<()> {
let path: PathBuf = path.into();
- let handle = task::spawn_blocking(move || write_sorted(batches, path,
schema));
- match handle.await {
+ let task = SpawnedTask::spawn_blocking(move || write_sorted(batches, path,
schema));
+ match task.join().await {
Ok(r) => r,
Err(e) => exec_err!("Error occurred while spilling {e}"),
}
diff --git a/datafusion/sqllogictest/bin/sqllogictests.rs b/datafusion/sqllogictest/bin/sqllogictests.rs
index ffae144eae..41c33deec6 100644
--- a/datafusion/sqllogictest/bin/sqllogictests.rs
+++ b/datafusion/sqllogictest/bin/sqllogictests.rs
@@ -88,6 +88,7 @@ async fn run_tests() -> Result<()> {
// modifying shared state like `/tmp/`)
let errors: Vec<_> = futures::stream::iter(read_test_files(&options)?)
.map(|test_file| {
+ #[allow(clippy::disallowed_methods)] // spawn allowed only in tests
tokio::task::spawn(async move {
println!("Running {:?}", test_file.relative_path);
if options.complete {
diff --git a/dev/release/rat_exclude_files.txt
b/dev/release/rat_exclude_files.txt
index f99d6e15e8..ce5635b6da 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -136,4 +136,5 @@ datafusion/proto/src/generated/prost.rs
.github/ISSUE_TEMPLATE/feature_request.yml
.github/workflows/docs.yaml
**/node_modules/*
-datafusion/wasmtest/pkg/*
\ No newline at end of file
+datafusion/wasmtest/pkg/*
+clippy.toml