This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 581fd98270 refactor: add `join_unwind` to `SpawnedTask` (#9422)
581fd98270 is described below

commit 581fd98270e69221689ddd4c37566ef620a67167
Author: Artem Medvedev <[email protected]>
AuthorDate: Mon Mar 4 12:11:10 2024 +0100

    refactor: add `join_unwind` to `SpawnedTask` (#9422)
    
    * refactor: add `join_unwind` to `SpawnedTask`
    
    In order to remove duplication of these handlers, it seems logical to have
    such a method.
    
    I thought about adding this logic to `join`, but there are methods with
    additional logic
    
    * docs: improve join_unwind comments
    
    * docs: improve join_unwind comments
---
 clippy.toml                                        |  2 +-
 datafusion/common_runtime/src/common.rs            | 17 +++++
 .../core/src/datasource/file_format/arrow.rs       | 11 +---
 .../core/src/datasource/file_format/parquet.rs     | 75 +++++-----------------
 .../datasource/file_format/write/orchestration.rs  | 21 ++----
 datafusion/core/src/datasource/stream.rs           |  2 +-
 6 files changed, 44 insertions(+), 84 deletions(-)

diff --git a/clippy.toml b/clippy.toml
index 6eb9906c89..62d8263085 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -1,6 +1,6 @@
 disallowed-methods = [
     { path = "tokio::task::spawn", reason = "To provide cancel-safety, use 
`SpawnedTask::spawn` instead 
(https://github.com/apache/arrow-datafusion/issues/6513)" },
-    { path = "tokio::task::spawn_blocking", reason = "To provide 
cancel-safety, use `SpawnedTask::spawn` instead 
(https://github.com/apache/arrow-datafusion/issues/6513)" },
+    { path = "tokio::task::spawn_blocking", reason = "To provide 
cancel-safety, use `SpawnedTask::spawn_blocking` instead 
(https://github.com/apache/arrow-datafusion/issues/6513)" },
 ]
 
 disallowed-types = [
diff --git a/datafusion/common_runtime/src/common.rs 
b/datafusion/common_runtime/src/common.rs
index 88b74448c7..2f7ddb972f 100644
--- a/datafusion/common_runtime/src/common.rs
+++ b/datafusion/common_runtime/src/common.rs
@@ -51,10 +51,27 @@ impl<R: 'static> SpawnedTask<R> {
         Self { inner }
     }
 
+    /// Joins the task, returning the result of join (`Result<R, JoinError>`).
     pub async fn join(mut self) -> Result<R, JoinError> {
         self.inner
             .join_next()
             .await
             .expect("`SpawnedTask` instance always contains exactly 1 task")
     }
+
+    /// Joins the task and unwinds the panic if it happens.
+    pub async fn join_unwind(self) -> R {
+        self.join().await.unwrap_or_else(|e| {
+            // `JoinError` can be caused either by panic or cancellation. We 
have to handle panics:
+            if e.is_panic() {
+                std::panic::resume_unwind(e.into_panic());
+            } else {
+                // Cancellation may be caused by two reasons:
+                // 1. Abort is called, but since we consumed `self`, it's not 
our case (`JoinHandle` not accessible outside).
+                // 2. The runtime is shutting down.
+                // So we consider this branch as unreachable.
+                unreachable!("SpawnedTask was cancelled unexpectedly");
+            }
+        })
+    }
 }
diff --git a/datafusion/core/src/datasource/file_format/arrow.rs 
b/datafusion/core/src/datasource/file_format/arrow.rs
index d5f07d11be..90417a9781 100644
--- a/datafusion/core/src/datasource/file_format/arrow.rs
+++ b/datafusion/core/src/datasource/file_format/arrow.rs
@@ -295,16 +295,7 @@ impl DataSink for ArrowFileSink {
             }
         }
 
-        match demux_task.join().await {
-            Ok(r) => r?,
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        demux_task.join_unwind().await?;
         Ok(row_count as u64)
     }
 }
diff --git a/datafusion/core/src/datasource/file_format/parquet.rs 
b/datafusion/core/src/datasource/file_format/parquet.rs
index 4ea6c2a273..3824177cb3 100644
--- a/datafusion/core/src/datasource/file_format/parquet.rs
+++ b/datafusion/core/src/datasource/file_format/parquet.rs
@@ -729,16 +729,7 @@ impl DataSink for ParquetSink {
             }
         }
 
-        match demux_task.join().await {
-            Ok(r) => r?,
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
-            }
-        }
+        demux_task.join_unwind().await?;
 
         Ok(row_count as u64)
     }
@@ -831,19 +822,8 @@ fn spawn_rg_join_and_finalize_task(
         let num_cols = column_writer_tasks.len();
         let mut finalized_rg = Vec::with_capacity(num_cols);
         for task in column_writer_tasks.into_iter() {
-            match task.join().await {
-                Ok(r) => {
-                    let w = r?;
-                    finalized_rg.push(w.close()?);
-                }
-                Err(e) => {
-                    if e.is_panic() {
-                        std::panic::resume_unwind(e.into_panic())
-                    } else {
-                        unreachable!()
-                    }
-                }
-            }
+            let writer = task.join_unwind().await?;
+            finalized_rg.push(writer.close()?);
         }
 
         Ok((finalized_rg, rg_rows))
@@ -952,31 +932,21 @@ async fn concatenate_parallel_row_groups(
     let mut row_count = 0;
 
     while let Some(task) = serialize_rx.recv().await {
-        match task.join().await {
-            Ok(result) => {
-                let mut rg_out = parquet_writer.next_row_group()?;
-                let (serialized_columns, cnt) = result?;
-                row_count += cnt;
-                for chunk in serialized_columns {
-                    chunk.append_to_row_group(&mut rg_out)?;
-                    let mut buff_to_flush = 
merged_buff.buffer.try_lock().unwrap();
-                    if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
-                        object_store_writer
-                            .write_all(buff_to_flush.as_slice())
-                            .await?;
-                        buff_to_flush.clear();
-                    }
-                }
-                rg_out.close()?;
-            }
-            Err(e) => {
-                if e.is_panic() {
-                    std::panic::resume_unwind(e.into_panic());
-                } else {
-                    unreachable!();
-                }
+        let result = task.join_unwind().await;
+        let mut rg_out = parquet_writer.next_row_group()?;
+        let (serialized_columns, cnt) = result?;
+        row_count += cnt;
+        for chunk in serialized_columns {
+            chunk.append_to_row_group(&mut rg_out)?;
+            let mut buff_to_flush = merged_buff.buffer.try_lock().unwrap();
+            if buff_to_flush.len() > BUFFER_FLUSH_BYTES {
+                object_store_writer
+                    .write_all(buff_to_flush.as_slice())
+                    .await?;
+                buff_to_flush.clear();
             }
         }
+        rg_out.close()?;
     }
 
     let inner_writer = parquet_writer.into_inner()?;
@@ -1020,18 +990,7 @@ async fn output_single_parquet_file_parallelized(
     )
     .await?;
 
-    match launch_serialization_task.join().await {
-        Ok(Ok(_)) => (),
-        Ok(Err(e)) => return Err(e),
-        Err(e) => {
-            if e.is_panic() {
-                std::panic::resume_unwind(e.into_panic())
-            } else {
-                unreachable!()
-            }
-        }
-    }
-
+    launch_serialization_task.join_unwind().await?;
     Ok(row_count)
 }
 
diff --git a/datafusion/core/src/datasource/file_format/write/orchestration.rs 
b/datafusion/core/src/datasource/file_format/write/orchestration.rs
index dd0e5ce6a4..b7f2689593 100644
--- a/datafusion/core/src/datasource/file_format/write/orchestration.rs
+++ b/datafusion/core/src/datasource/file_format/write/orchestration.rs
@@ -34,7 +34,7 @@ use datafusion_common_runtime::SpawnedTask;
 use datafusion_execution::TaskContext;
 
 use bytes::Bytes;
-use futures::try_join;
+use futures::join;
 use tokio::io::{AsyncWrite, AsyncWriteExt};
 use tokio::sync::mpsc::{self, Receiver};
 use tokio::task::JoinSet;
@@ -264,19 +264,12 @@ pub(crate) async fn stateless_multipart_put(
     // Signal to the write coordinator that no more files are coming
     drop(tx_file_bundle);
 
-    match try_join!(write_coordinator_task.join(), demux_task.join()) {
-        Ok((r1, r2)) => {
-            r1?;
-            r2?;
-        }
-        Err(e) => {
-            if e.is_panic() {
-                std::panic::resume_unwind(e.into_panic());
-            } else {
-                unreachable!();
-            }
-        }
-    }
+    let (r1, r2) = join!(
+        write_coordinator_task.join_unwind(),
+        demux_task.join_unwind()
+    );
+    r1?;
+    r2?;
 
     let total_count = rx_row_cnt.await.map_err(|_| {
         internal_datafusion_err!("Did not receieve row count from write 
coordinater")
diff --git a/datafusion/core/src/datasource/stream.rs 
b/datafusion/core/src/datasource/stream.rs
index 0d91b1cba3..079c1a891d 100644
--- a/datafusion/core/src/datasource/stream.rs
+++ b/datafusion/core/src/datasource/stream.rs
@@ -359,6 +359,6 @@ impl DataSink for StreamWrite {
             }
         }
         drop(sender);
-        write_task.join().await.unwrap()
+        write_task.join_unwind().await
     }
 }

Reply via email to