This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 581fd98270 refactor: add `join_unwind` to `SpawnedTask` (#9422)
581fd98270 is described below
commit 581fd98270e69221689ddd4c37566ef620a67167
Author: Artem Medvedev <[email protected]>
AuthorDate: Mon Mar 4 12:11:10 2024 +0100
refactor: add `join_unwind` to `SpawnedTask` (#9422)
* refactor: add `join_unwind` to `SpawnedTask`
In order to remove duplication of these handlers, it seems logical to have
such a method.
I considered adding this logic to `join` itself, but some call sites need
additional handling, so a separate method is provided instead.
* docs: improve join_unwind comments
* docs: improve join_unwind comments
---
clippy.toml | 2 +-
datafusion/common_runtime/src/common.rs | 17 +++++
.../core/src/datasource/file_format/arrow.rs | 11 +---
.../core/src/datasource/file_format/parquet.rs | 75 +++++-----------------
.../datasource/file_format/write/orchestration.rs | 21 ++----
datafusion/core/src/datasource/stream.rs | 2 +-
6 files changed, 44 insertions(+), 84 deletions(-)
diff --git a/clippy.toml b/clippy.toml
index 6eb9906c89..62d8263085 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -1,6 +1,6 @@
disallowed-methods = [
{ path = "tokio::task::spawn", reason = "To provide cancel-safety, use
`SpawnedTask::spawn` instead
(https://github.com/apache/arrow-datafusion/issues/6513)" },
- { path = "tokio::task::spawn_blocking", reason = "To provide
cancel-safety, use `SpawnedTask::spawn` instead
(https://github.com/apache/arrow-datafusion/issues/6513)" },
+ { path = "tokio::task::spawn_blocking", reason = "To provide
cancel-safety, use `SpawnedTask::spawn_blocking` instead
(https://github.com/apache/arrow-datafusion/issues/6513)" },
]
disallowed-types = [
diff --git a/datafusion/common_runtime/src/common.rs
b/datafusion/common_runtime/src/common.rs
index 88b74448c7..2f7ddb972f 100644
--- a/datafusion/common_runtime/src/common.rs
+++ b/datafusion/common_runtime/src/common.rs
@@ -51,10 +51,27 @@ impl<R: 'static> SpawnedTask<R> {
Self { inner }
}
+ /// Joins the task, returning the result of join (`Result<R, JoinError>`).
pub async fn join(mut self) -> Result<R, JoinError> {
self.inner
.join_next()
.await
.expect("`SpawnedTask` instance always contains exactly 1 task")
}
+
+ /// Joins the task and unwinds the panic if it happens.
+ pub async fn join_unwind(self) -> R {
+ self.join().await.unwrap_or_else(|e| {
+ // `JoinError` can be caused either by panic or cancellation. We
have to handle panics:
+ if e.is_panic() {
+ std::panic::resume_unwind(e.into_panic());
+ } else {
+ // Cancellation may be caused by two reasons:
+ // 1. Abort is called, but since we consumed `self`, it's not
our case (`JoinHandle` not accessible outside).
+ // 2. The runtime is shutting down.
+ // So we consider this branch as unreachable.
+ unreachable!("SpawnedTask was cancelled unexpectedly");
+ }
+ })
+ }
}
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs
b/datafusion/core/src/datasource/file_format/arrow.rs
index d5f07d11be..90417a9781 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,16 +295,7 @@ impl DataSink for ArrowFileSink {
}
}
- match demux_task.join().await {
- Ok(r) => r?,
- Err(e) => {
- if e.is_panic() {
- std::panic::resume_unwind(e.into_panic());
- } else {
- unreachable!();
- }
- }
- }
+ demux_task.join_unwind().await?;
Ok(row_count as u64)
}
}
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs
b/datafusion/core/src/datasource/file_format/parquet.rs
index 4ea6c2a273..3824177cb3 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -729,16 +729,7 @@ impl DataSink for ParquetSink {
}
}
- match demux_task.join().await {
- Ok(r) => r?,
- Err(e) => {
- if e.is_panic() {
- std::panic::resume_unwind(e.into_panic());
- } else {
- unreachable!();
- }
- }
- }
+ demux_task.join_unwind().await?;
Ok(row_count as u64)
}
@@ -831,19 +822,8 @@ fn spawn_rg_join_and_finalize_task(
let num_cols = column_writer_tasks.len();
let mut finalized_rg = Vec::with_capacity(num_cols);
for task in column_writer_tasks.into_iter() {
- match task.join().await {
- Ok(r) => {
- let w = r?;
- finalized_rg.push(w.close()?);
- }
- Err(e) => {
- if e.is_panic() {
- std::panic::resume_unwind(e.into_panic())
- } else {
- unreachable!()
- }
- }
- }
+ let writer = task.join_unwind().await?;
+ finalized_rg.push(writer.close()?);
}
Ok((finalized_rg, rg_rows))
@@ -952,31 +932,21 @@ async fn concatenate_parallel_row_groups(
let mut row_count = 0;
while let Some(task) = serialize_rx.recv().await {
- match task.join().await {
- Ok(result) => {
- let mut rg_out = parquet_writer.next_row_group()?;
- let (serialized_columns, cnt) = result?;
- row_count += cnt;
- for chunk in serialized_columns {
- chunk.append_to_row_group(&mut rg_out)?;
- let mut buff_to_flush =
merged_buff.buffer.try_lock().unwrap();
- if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
- object_store_writer
- .write_all(buff_to_flush.as_slice())
- .await?;
- buff_to_flush.clear();
- }
- }
- rg_out.close()?;
- }
- Err(e) => {
- if e.is_panic() {
- std::panic::resume_unwind(e.into_panic());
- } else {
- unreachable!();
- }
+ let result = task.join_unwind().await;
+ let mut rg_out = parquet_writer.next_row_group()?;
+ let (serialized_columns, cnt) = result?;
+ row_count += cnt;
+ for chunk in serialized_columns {
+ chunk.append_to_row_group(&mut rg_out)?;
+ let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
+ if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
+ object_store_writer
+ .write_all(buff_to_flush.as_slice())
+ .await?;
+ buff_to_flush.clear();
}
}
+ rg_out.close()?;
}
let inner_writer = parquet_writer.into_inner()?;
@@ -1020,18 +990,7 @@ async fn output_single_parquet_file_parallelized(
)
.await?;
- match launch_serialization_task.join().await {
- Ok(Ok(_)) => (),
- Ok(Err(e)) => return Err(e),
- Err(e) => {
- if e.is_panic() {
- std::panic::resume_unwind(e.into_panic())
- } else {
- unreachable!()
- }
- }
- }
-
+ launch_serialization_task.join_unwind().await?;
Ok(row_count)
}
diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs
b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index dd0e5ce6a4..b7f2689593 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -34,7 +34,7 @@ use datafusion_common_runtime::SpawnedTask;
use datafusion_execution::TaskContext;
use bytes::Bytes;
-use futures::try_join;
+use futures::join;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::sync::mpsc::{self, Receiver};
use tokio::task::JoinSet;
@@ -264,19 +264,12 @@ pub(crate) async fn stateless_multipart_put(
// Signal to the write coordinator that no more files are coming
drop(tx_file_bundle);
- match try_join!(write_coordinator_task.join(), demux_task.join()) {
- Ok((r1, r2)) => {
- r1?;
- r2?;
- }
- Err(e) => {
- if e.is_panic() {
- std::panic::resume_unwind(e.into_panic());
- } else {
- unreachable!();
- }
- }
- }
+ let (r1, r2) = join!(
+ write_coordinator_task.join_unwind(),
+ demux_task.join_unwind()
+ );
+ r1?;
+ r2?;
let total_count = rx_row_cnt.await.map_err(|_| {
internal_datafusion_err!("Did not receieve row count from write
coordinater")
diff --git a/datafusion/core/src/datasource/stream.rs
b/datafusion/core/src/datasource/stream.rs
index 0d91b1cba3..079c1a891d 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -359,6 +359,6 @@ impl DataSink for StreamWrite {
}
}
drop(sender);
- write_task.join().await.unwrap()
+ write_task.join_unwind().await
}
}