This is an automated email from the ASF dual-hosted git repository.

avantgardner pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 05c0aaca Use StreamWriter instead of FileWriter (#943)
05c0aaca is described below

commit 05c0aacae2ed3543a4887fe252464e6c03b30639
Author: Brent Gardner <[email protected]>
AuthorDate: Thu Dec 21 09:12:44 2023 -0700

    Use StreamWriter instead of FileWriter (#943)
    
    Use StreamWriter instead of FileWriter (#943)
---
 .../core/src/execution_plans/shuffle_reader.rs     | 18 ++++----
 .../core/src/execution_plans/shuffle_writer.rs     | 52 +++++++++++++++-------
 ballista/core/src/utils.rs                         |  5 ++-
 ballista/executor/src/flight_service.rs            |  9 ++--
 4 files changed, 53 insertions(+), 31 deletions(-)

diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index 6a77a16e..491e4d05 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -16,11 +16,13 @@
 // under the License.
 
 use async_trait::async_trait;
+use datafusion::arrow::ipc::reader::StreamReader;
 use datafusion::common::stats::Precision;
 use std::any::Any;
 use std::collections::HashMap;
 use std::fmt::Debug;
 use std::fs::File;
+use std::io::BufReader;
 use std::pin::Pin;
 use std::result;
 use std::sync::Arc;
@@ -31,7 +33,6 @@ use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::error::ArrowError;
-use datafusion::arrow::ipc::reader::FileReader;
 use datafusion::arrow::record_batch::RecordBatch;
 
 use datafusion::error::Result;
@@ -209,11 +210,11 @@ fn stats_for_partitions(
 }
 
 struct LocalShuffleStream {
-    reader: FileReader<File>,
+    reader: StreamReader<BufReader<File>>,
 }
 
 impl LocalShuffleStream {
-    pub fn new(reader: FileReader<File>) -> Self {
+    pub fn new(reader: StreamReader<BufReader<File>>) -> Self {
         LocalShuffleStream { reader }
     }
 }
@@ -412,13 +413,14 @@ async fn fetch_partition_local(
 
 fn fetch_partition_local_inner(
     path: &str,
-) -> result::Result<FileReader<File>, BallistaError> {
+) -> result::Result<StreamReader<BufReader<File>>, BallistaError> {
     let file = File::open(path).map_err(|e| {
         BallistaError::General(format!("Failed to open partition file at {path}: {e:?}"))
     })?;
-    FileReader::try_new(file, None).map_err(|e| {
+    let reader = StreamReader::try_new(file, None).map_err(|e| {
         BallistaError::General(format!("Failed to new arrow FileReader at {path}: {e:?}"))
-    })
+    })?;
+    Ok(reader)
 }
 
 async fn fetch_partition_object_store(
@@ -437,7 +439,7 @@ mod tests {
     use crate::utils;
     use datafusion::arrow::array::{Int32Array, StringArray, UInt32Array};
     use datafusion::arrow::datatypes::{DataType, Field, Schema};
-    use datafusion::arrow::ipc::writer::FileWriter;
+    use datafusion::arrow::ipc::writer::StreamWriter;
     use datafusion::arrow::record_batch::RecordBatch;
     use datafusion::common::DataFusionError;
     use datafusion::physical_expr::expressions::Column;
@@ -627,7 +629,7 @@ mod tests {
         let tmp_dir = tempdir().unwrap();
         let file_path = tmp_dir.path().join("shuffle_data");
         let file = File::create(&file_path).unwrap();
-        let mut writer = FileWriter::try_new(file, &schema).unwrap();
+        let mut writer = StreamWriter::try_new(file, &schema).unwrap();
         writer.write(&batch).unwrap();
         writer.finish().unwrap();
 
diff --git a/ballista/core/src/execution_plans/shuffle_writer.rs b/ballista/core/src/execution_plans/shuffle_writer.rs
index 2540a1d2..19d35477 100644
--- a/ballista/core/src/execution_plans/shuffle_writer.rs
+++ b/ballista/core/src/execution_plans/shuffle_writer.rs
@@ -24,7 +24,10 @@ use datafusion::arrow::ipc::writer::IpcWriteOptions;
 use datafusion::arrow::ipc::CompressionType;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 
+use datafusion::arrow::ipc::writer::StreamWriter;
 use std::any::Any;
+use std::fs;
+use std::fs::File;
 use std::future::Future;
 use std::iter::Iterator;
 use std::path::PathBuf;
@@ -42,7 +45,6 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::error::{DataFusionError, Result};
-use datafusion::physical_plan::common::IPCWriter;
 use datafusion::physical_plan::memory::MemoryStream;
 use datafusion::physical_plan::metrics::{
     self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet,
@@ -81,6 +83,13 @@ pub struct ShuffleWriterExec {
     metrics: ExecutionPlanMetricsSet,
 }
 
+pub struct WriteTracker {
+    pub num_batches: usize,
+    pub num_rows: usize,
+    pub writer: StreamWriter<File>,
+    pub path: PathBuf,
+}
+
 #[derive(Debug, Clone)]
 struct ShuffleWriteMetrics {
     /// Time spend writing batches to shuffle files
@@ -210,7 +219,7 @@ impl ShuffleWriterExec {
                 Some(Partitioning::Hash(exprs, num_output_partitions)) => {
                     // we won't necessary produce output for every possible partition, so we
                     // create writers on demand
-                    let mut writers: Vec<Option<IPCWriter>> = vec![];
+                    let mut writers: Vec<Option<WriteTracker>> = vec![];
                     for _ in 0..num_output_partitions {
                         writers.push(None);
                     }
@@ -232,7 +241,9 @@ impl ShuffleWriterExec {
                                 let timer = write_metrics.write_time.timer();
                                 match &mut writers[output_partition] {
                                     Some(w) => {
-                                        w.write(&output_batch)?;
+                                        w.num_batches += 1;
+                                        w.num_rows += output_batch.num_rows();
+                                        w.writer.write(&output_batch)?;
                                     }
                                     None => {
                                         let mut path = path.clone();
@@ -248,14 +259,22 @@ impl ShuffleWriterExec {
                                             .try_with_compression(Some(
                                                 CompressionType::LZ4_FRAME,
                                             ))?;
-                                        let mut writer = IPCWriter::new_with_options(
-                                            &path,
-                                            stream.schema().as_ref(),
-                                            options,
-                                        )?;
+
+                                        let file = File::create(path.clone())?;
+                                        let mut writer =
+                                            StreamWriter::try_new_with_options(
+                                                file,
+                                                stream.schema().as_ref(),
+                                                options,
+                                            )?;
 
                                         writer.write(&output_batch)?;
-                                        writers[output_partition] = Some(writer);
+                                        writers[output_partition] = Some(WriteTracker {
+                                            num_batches: 1,
+                                            num_rows: output_batch.num_rows(),
+                                            writer,
+                                            path,
+                                        });
                                     }
                                 }
                                 write_metrics.output_rows.add(output_batch.num_rows());
@@ -270,22 +289,23 @@ impl ShuffleWriterExec {
                     for (i, w) in writers.iter_mut().enumerate() {
                         match w {
                             Some(w) => {
-                                w.finish()?;
+                                let num_bytes = fs::metadata(&w.path)?.len();
+                                w.writer.finish()?;
                                 debug!(
                                     "Finished writing shuffle partition {} at {:?}. Batches: {}. Rows: {}. Bytes: {}.",
                                     i,
-                                    w.path(),
+                                    w.path,
                                     w.num_batches,
                                     w.num_rows,
-                                    w.num_bytes
+                                    num_bytes
                                 );
 
                                 part_locs.push(ShuffleWritePartition {
                                     partition_id: i as u64,
-                                    path: w.path().to_string_lossy().to_string(),
-                                    num_batches: w.num_batches,
-                                    num_rows: w.num_rows,
-                                    num_bytes: w.num_bytes,
+                                    path: w.path.to_string_lossy().to_string(),
+                                    num_batches: w.num_batches as u64,
+                                    num_rows: w.num_rows as u64,
+                                    num_bytes,
                                 });
                             }
                             None => {}
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index a6541c08..45a4f53f 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -26,8 +26,9 @@ use crate::serde::scheduler::PartitionStats;
 use async_trait::async_trait;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::ipc::writer::IpcWriteOptions;
+use datafusion::arrow::ipc::writer::StreamWriter;
 use datafusion::arrow::ipc::CompressionType;
-use datafusion::arrow::{ipc::writer::FileWriter, record_batch::RecordBatch};
+use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::physical_plan::{CsvExec, ParquetExec};
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{
@@ -89,7 +90,7 @@ pub async fn write_stream_to_disk(
         .try_with_compression(Some(CompressionType::LZ4_FRAME))?;
 
     let mut writer =
-        FileWriter::try_new_with_options(file, stream.schema().as_ref(), options)?;
+        StreamWriter::try_new_with_options(file, stream.schema().as_ref(), options)?;
 
     while let Some(result) = stream.next().await {
         let batch = result?;
diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs
index ed8ffb35..c4ffab4f 100644
--- a/ballista/executor/src/flight_service.rs
+++ b/ballista/executor/src/flight_service.rs
@@ -17,6 +17,7 @@
 
 //! Implementation of the Apache Arrow Flight protocol that wraps an executor.
 
+use arrow::ipc::reader::StreamReader;
 use std::convert::TryFrom;
 use std::fs::File;
 use std::pin::Pin;
@@ -34,9 +35,7 @@ use arrow_flight::{
     FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
     PutResult, SchemaResult, Ticket,
 };
-use datafusion::arrow::{
-    error::ArrowError, ipc::reader::FileReader, record_batch::RecordBatch,
-};
+use datafusion::arrow::{error::ArrowError, record_batch::RecordBatch};
 use futures::{Stream, StreamExt, TryStreamExt};
 use log::{debug, info};
 use std::io::{Read, Seek};
@@ -97,7 +96,7 @@ impl FlightService for BallistaFlightService {
                     })
                     .map_err(|e| from_ballista_err(&e))?;
                 let reader =
-                    FileReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;
+                    StreamReader::try_new(file, None).map_err(|e| from_arrow_err(&e))?;
 
                 let (tx, rx) = channel(2);
                 let schema = reader.schema();
@@ -207,7 +206,7 @@ impl FlightService for BallistaFlightService {
 }
 
 fn read_partition<T>(
-    reader: FileReader<T>,
+    reader: StreamReader<std::io::BufReader<T>>,
     tx: Sender<Result<RecordBatch, FlightError>>,
 ) -> Result<(), FlightError>
 where

Reply via email to