metesynnada commented on code in PR #7452:
URL: https://github.com/apache/arrow-datafusion/pull/7452#discussion_r1318859784


##########
datafusion/core/src/datasource/file_format/write.rs:
##########
@@ -315,58 +300,237 @@ pub(crate) async fn create_writer(
     }
 }
 
+/// Serializes a single data stream in parallel and writes to an ObjectStore
+/// concurrently. Data order is preserved. In the event of an error,
+/// the ObjectStore writer is returned to the caller in addition to an error,
+/// so that the caller may handle aborting failed writes.
+async fn serialize_rb_stream_to_object_store(
+    mut data_stream: Pin<Box<dyn RecordBatchStream + Send>>,
+    mut serializer: Box<dyn BatchSerializer>,
+    mut writer: AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
+    unbounded_input: bool,
+) -> std::result::Result<
+    (
+        Box<dyn BatchSerializer>,
+        AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
+        u64,
+    ),
+    (
+        AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>,
+        DataFusionError,
+    ),
+> {
+    let (tx, mut rx) =
+        mpsc::channel::<JoinHandle<Result<(usize, Bytes), DataFusionError>>>(100);
+
+    let serialize_task = tokio::spawn(async move {
+        while let Some(maybe_batch) = data_stream.next().await {
+            match serializer.duplicate() {
+                Ok(mut serializer_clone) => {
+                    let handle = tokio::spawn(async move {
+                        let batch = maybe_batch?;
+                        let num_rows = batch.num_rows();
+                        let bytes = serializer_clone.serialize(batch).await?;
+                        Ok((num_rows, bytes))
+                    });
+                    tx.send(handle).await.map_err(|_| {
+                        DataFusionError::Internal(
+                            "Unknown error writing to object store".into(),
+                        )
+                    })?;
+                    if unbounded_input {
+                        tokio::task::yield_now().await;
+                    }
+                }
+                Err(_) => {
+                    return Err(DataFusionError::Internal(
+                        "Unknown error writing to object store".into(),
+                    ))
+                }
+            }
+        }
+        Ok(serializer)
+    });
+
+    let mut row_count = 0;
+    while let Some(handle) = rx.recv().await {
+        match handle.await {
+            Ok(Ok((cnt, bytes))) => {
+                match writer.write_all(&bytes).await {
+                    Ok(_) => (),
+                    Err(_) => {
+                        return Err((
+                            writer,
+                            DataFusionError::Internal(
+                                "Unknown error writing to object store".into(),
+                            ),
+                        ))
+                    }
+                };
+                row_count += cnt;
+            }
+            Ok(Err(e)) => {
+                // Return the writer along with the error
+                return Err((writer, e));
+            }
+            Err(_) => {
+                // Handle task panic or cancellation
+                return Err((
+                    writer,
+                    DataFusionError::Internal(
+                        "Serialization task panicked or was cancelled".into(),
+                    ),
+                ));
+            }
+        }
+    }
+
+    let serializer = match serialize_task.await {
+        Ok(Ok(serializer)) => serializer,
+        Ok(Err(e)) => return Err((writer, e)),
+        Err(_) => {
+            return Err((
+                writer,
+                DataFusionError::Internal("Unknown error writing to object store".into()),
+            ))
+        }
+    };
+    Ok((serializer, writer, row_count as u64))
+}
+
 /// Contains the common logic for serializing RecordBatches and
 /// writing the resulting bytes to an ObjectStore.
 /// Serialization is assumed to be stateless, i.e.
 /// each RecordBatch can be serialized without any
 /// dependency on the RecordBatches before or after.
 pub(crate) async fn stateless_serialize_and_write_files(
-    mut data: Vec<SendableRecordBatchStream>,
+    data: Vec<SendableRecordBatchStream>,
     mut serializers: Vec<Box<dyn BatchSerializer>>,
     mut writers: Vec<AbortableWrite<Box<dyn AsyncWrite + Send + Unpin>>>,
     single_file_output: bool,
+    unbounded_input: bool,
 ) -> Result<u64> {
     if single_file_output && (serializers.len() != 1 || writers.len() != 1) {
         return internal_err!("single_file_output is true, but got more than 1 
writer!");
     }
     let num_partitions = data.len();
-    if !single_file_output && (num_partitions != writers.len()) {
+    let num_writers = writers.len();
+    if !single_file_output && (num_partitions != num_writers) {
         return internal_err!("single_file_ouput is false, but did not get 1 
writer for each output partition!");
     }
     let mut row_count = 0;
-    // Map errors to DatafusionError.
-    let err_converter =
-        |_| DataFusionError::Internal("Unexpected FileSink Error".to_string());
-    // TODO parallelize serialization accross partitions and batches within partitions
-    // see: https://github.com/apache/arrow-datafusion/issues/7079
-    for (part_idx, data_stream) in data.iter_mut().enumerate().take(num_partitions) {
-        let idx = match single_file_output {
-            false => part_idx,
-            true => 0,
-        };
-        while let Some(maybe_batch) = data_stream.next().await {
-            // Write data to files in a round robin fashion:
-            let serializer = &mut serializers[idx];
-            let batch = check_for_errors(maybe_batch, &mut writers).await?;
-            row_count += batch.num_rows();
-            let bytes =
-                check_for_errors(serializer.serialize(batch).await, &mut writers).await?;
-            let writer = &mut writers[idx];
-            check_for_errors(
-                writer.write_all(&bytes).await.map_err(err_converter),
-                &mut writers,
-            )
-            .await?;
+    // tracks if any writers encountered an error triggering the need to abort
+    let mut any_errors = false;
+    // tracks the specific error triggering abort
+    let mut triggering_error = None;
+    // tracks if any errors were encountered in the process of aborting writers.
+    // if true, we may not have a guarentee that all written data was cleaned up.
+    let mut any_abort_errors = false;
+    match single_file_output {
+        false => {
+            let mut join_set = JoinSet::new();
+            for (data_stream, serializer, writer) in data
+                .into_iter()
+                .zip(serializers.into_iter())
+                .zip(writers.into_iter())
+                .map(|((a, b), c)| (a, b, c))
+            {
+                join_set.spawn(async move {
+                    serialize_rb_stream_to_object_store(
+                        data_stream,
+                        serializer,
+                        writer,
+                        unbounded_input,
+                    )
+                    .await
+                });
+            }
+            let mut finished_writers = Vec::with_capacity(num_writers);
+            while let Some(result) = join_set.join_next().await {
+                match result {
+                    Ok(res) => match res {
+                        Ok((_, writer, cnt)) => {
+                            finished_writers.push(writer);
+                            row_count += cnt;
+                        }
+                        Err((writer, e)) => {
+                            finished_writers.push(writer);
+                            any_errors = true;
+                            triggering_error = Some(e);
+                        }
+                    },
+                    Err(_) => {
+                        // Don't panic, instead try to clean up as many writers as possible.
+                        // If we hit this code, ownership of a writer was not joined back to
+                        // this thread, so we cannot clean it up (hence any_abort_errors is true)
+                        any_errors = true;
+                        any_abort_errors = true;
+                    }
+                }
+            }
+
+            // Finalize or abort writers as appropriate
+            for mut writer in finished_writers.into_iter() {
+                match any_errors {
+                    true => {
+                        let abort_result = writer.abort_writer();
+                        if abort_result.is_err() {
+                            any_abort_errors = true;
+                        }
+                    }
+                    false => {
+                        // TODO if we encounter an error during shutdown, delete previously written files?
+                        writer.shutdown()
+                            .await
+                            .map_err(|_| DataFusionError::Internal("Error encountered while finalizing writes! Partial results may have been written to ObjectStore!".into()))?;
+                    }
+                }
+            }
+        }
+        true => {
+            let mut writer = writers.remove(0);
+            let mut serializer = serializers.remove(0);
+            let mut cnt;
+            for data_stream in data.into_iter() {
+                (serializer, writer, cnt) = match serialize_rb_stream_to_object_store(
+                    data_stream,
+                    serializer,
+                    writer,
+                    unbounded_input,
+                )
+                .await
+                {
+                    Ok((s, w, c)) => (s, w, c),

Review Comment:
   Consider adding a comment on these types of declarations (the uninitialized `let mut cnt;` and the destructuring reassignment).
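   
   For illustration, a commented version of that declare-then-reassign pattern might look like the sketch below. The `Serializer`, `Writer`, and `process` names are hypothetical stand-ins rather than the actual types in write.rs; the point is only why `cnt` needs no initializer and how the destructuring assignment threads ownership of the serializer and writer through each loop iteration.
   
   ```rust
   struct Serializer;
   struct Writer;
   
   // Stand-in for serialize_rb_stream_to_object_store: takes ownership of the
   // serializer and writer and hands them back along with a row count.
   fn process(s: Serializer, w: Writer, rows: u64) -> (Serializer, Writer, u64) {
       (s, w, rows)
   }
   
   fn main() {
       let mut serializer = Serializer;
       let mut writer = Writer;
       // Declared without an initializer: it is only read after being assigned
       // inside the loop, which the compiler verifies.
       let mut cnt;
       let mut row_count = 0u64;
       for rows in [10u64, 20, 30] {
           // Destructuring assignment (stable since Rust 1.59) reassigns the
           // existing bindings instead of shadowing them, so ownership of the
           // serializer and writer flows back for the next iteration.
           (serializer, writer, cnt) = process(serializer, writer, rows);
           row_count += cnt;
       }
       // The writer is still owned here and could be shut down or aborted.
       drop(writer);
       println!("total rows: {row_count}");
   }
   ```
   
   In the PR this is what lets a single writer and serializer be reused across all input partitions when single_file_output is true, while still being returned to the caller on error so the write can be aborted.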
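   
   Separately, the doc comment on serialize_rb_stream_to_object_store says serialization happens in parallel while data order is preserved. A minimal sketch of that mechanism, assuming a tokio runtime and using a string payload as a placeholder for the serialized bytes, is:
   
   ```rust
   use tokio::{sync::mpsc, task::JoinHandle};
   
   #[tokio::main]
   async fn main() {
       // Handles are sent in input order; the tasks themselves run concurrently.
       let (tx, mut rx) = mpsc::channel::<JoinHandle<String>>(100);
   
       let producer = tokio::spawn(async move {
           for i in 0..5 {
               // Stand-in for serializer_clone.serialize(batch).
               let handle =
                   tokio::spawn(async move { format!("serialized batch {i}") });
               // Sending the JoinHandle (not the finished result) is what keeps
               // the output ordered.
               tx.send(handle).await.expect("receiver dropped");
           }
           // tx is dropped here, which ends the consumer loop below.
       });
   
       // Await handles in the order they were sent, so the "writes" happen in
       // input order even though serialization overlapped.
       while let Some(handle) = rx.recv().await {
           let bytes = handle.await.expect("serialization task panicked");
           println!("{bytes}");
       }
   
       producer.await.expect("producer panicked");
   }
   ```
   
   The bounded channel (capacity 100 in the PR) also provides backpressure: the producer blocks on send once the consumer falls 100 batches behind.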



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
