This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new f305a76e feat: Improve Remote Shuffle Read Speed and Resource Utilisation  (#1318)
f305a76e is described below

commit f305a76e1e0da0dbf80c6bd6d9f033e1845b3a5e
Author: Marko Milenković <milenkov...@users.noreply.github.com>
AuthorDate: Sat Sep 13 19:03:11 2025 +0100

    feat: Improve Remote Shuffle Read Speed and Resource Utilisation  (#1318)
    
    * add block stream implementation
    
    * integrate do_action
    
    * add flight server action
    
    * add config option
    
    * fix action name
    
    * fix clippy issues
    
    * add config test
    
    * remove drop
    
    * improve messages
    
    * minor comment update
    
    * minor comment
    
    * fix clippy
---
 Cargo.lock                                         |   1 +
 ballista/client/tests/context_checks.rs            |  63 ++++
 ballista/core/src/client.rs                        | 390 ++++++++++++++++++++-
 ballista/core/src/config.rs                        |  13 +
 .../core/src/execution_plans/distributed_query.rs  |   4 +-
 .../core/src/execution_plans/shuffle_reader.rs     |  22 +-
 ballista/core/src/extension.rs                     |  31 +-
 ballista/executor/Cargo.toml                       |   1 +
 ballista/executor/src/flight_service.rs            |  64 +++-
 9 files changed, 570 insertions(+), 19 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 58d3ae03..4e721e80 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1036,6 +1036,7 @@ dependencies = [
  "tempfile",
  "tokio",
  "tokio-stream",
+ "tokio-util",
  "tonic",
  "tracing",
  "tracing-appender",
diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs
index a4392f77..638add1c 100644
--- a/ballista/client/tests/context_checks.rs
+++ b/ballista/client/tests/context_checks.rs
@@ -872,4 +872,67 @@ mod supported {
 
         Ok(())
     }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_force_local_read_with_flight(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) -> datafusion::error::Result<()> {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await?;
+
+        ctx.sql("SET ballista.shuffle.force_remote_read = true")
+            .await?
+            .show()
+            .await?;
+
+        ctx.sql("SET ballista.shuffle.remote_read_prefer_flight = true")
+            .await?
+            .show()
+            .await?;
+
+        let result = ctx
+            .sql("select name, value from information_schema.df_settings where 
name like 'ballista.shuffle.remote_read_prefer_flight' order by name limit 1")
+            .await?
+            .collect()
+            .await?;
+
+        let expected = [
+            "+--------------------------------------------+-------+",
+            "| name                                       | value |",
+            "+--------------------------------------------+-------+",
+            "| ballista.shuffle.remote_read_prefer_flight | true  |",
+            "+--------------------------------------------+-------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+
+        let expected = [
+            "+------------+----------+",
+            "| string_col | count(*) |",
+            "+------------+----------+",
+            "| 30         | 1        |",
+            "| 31         | 2        |",
+            "+------------+----------+",
+        ];
+
+        let result = ctx
+            .sql("select string_col, count(*) from test where id > 4 group by 
string_col order by string_col")
+            .await?
+            .collect()
+            .await?;
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
 }
diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs
index c1da64ac..528570f9 100644
--- a/ballista/core/src/client.rs
+++ b/ballista/core/src/client.rs
@@ -25,7 +25,7 @@ use std::{
     task::{Context, Poll},
 };
 
-use crate::error::{BallistaError, Result};
+use crate::error::{BallistaError, Result as BResult};
 use crate::serde::scheduler::{Action, PartitionId};
 
 use arrow_flight;
@@ -33,12 +33,16 @@ use arrow_flight::utils::flight_data_to_arrow_batch;
 use arrow_flight::Ticket;
 use arrow_flight::{flight_service_client::FlightServiceClient, FlightData};
 use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::buffer::{Buffer, MutableBuffer};
+use datafusion::arrow::ipc::convert::try_schema_from_ipc_buffer;
+use datafusion::arrow::ipc::reader::StreamDecoder;
 use datafusion::arrow::{
     datatypes::{Schema, SchemaRef},
     error::ArrowError,
     record_batch::RecordBatch,
 };
 use datafusion::error::DataFusionError;
+use datafusion::error::Result;
 
 use crate::serde::protobuf;
 use crate::utils::create_grpc_client_connection;
@@ -61,7 +65,11 @@ const IO_RETRY_WAIT_TIME_MS: u64 = 3000;
 impl BallistaClient {
     /// Create a new BallistaClient to connect to the executor listening on the specified
     /// host and port
-    pub async fn try_new(host: &str, port: u16, max_message_size: usize) -> Result<Self> {
+    pub async fn try_new(
+        host: &str,
+        port: u16,
+        max_message_size: usize,
+    ) -> BResult<Self> {
         let addr = format!("http://{host}:{port}");
         debug!("BallistaClient connecting to {addr}");
         let connection =
@@ -81,7 +89,11 @@ impl BallistaClient {
         Ok(Self { flight_client })
     }
 
-    /// Fetch a partition from an executor
+    /// Retrieves a partition from an executor.
+    ///
+    /// Depending on the value of the `flight_transport` parameter, this method will utilize either
+    /// the Arrow Flight protocol for compatibility, or a more efficient block-based transfer mechanism.
+    /// The block-based transfer is optimized for performance and reduces computational overhead on the server.
     pub async fn fetch_partition(
         &mut self,
         executor_id: &str,
@@ -89,7 +101,8 @@ impl BallistaClient {
         path: &str,
         host: &str,
         port: u16,
-    ) -> Result<SendableRecordBatchStream> {
+        flight_transport: bool,
+    ) -> BResult<SendableRecordBatchStream> {
         let action = Action::FetchPartition {
             job_id: partition_id.job_id.clone(),
             stage_id: partition_id.stage_id,
@@ -98,8 +111,14 @@ impl BallistaClient {
             host: host.to_owned(),
             port,
         };
-        self.execute_action(&action)
-            .await
+
+        let result = if flight_transport {
+            self.execute_do_get(&action).await
+        } else {
+            self.execute_do_action(&action).await
+        };
+
+        result
             .map_err(|error| match error {
                 // map grpc connection error to partition fetch error.
                 BallistaError::GrpcActionError(msg) => {
@@ -122,11 +141,15 @@ impl BallistaClient {
             })
     }
 
-    /// Execute an action and retrieve the results
-    pub async fn execute_action(
+    /// Executes the specified action and retrieves the results from the remote executor.
+    ///
+    /// This method establishes a [FlightDataStream] to facilitate the transfer of data
+    /// using the Arrow Flight protocol. The [FlightDataStream] handles the streaming
+    /// of record batches from the server to the client in an efficient and structured manner.
+    pub async fn execute_do_get(
         &mut self,
         action: &Action,
-    ) -> Result<SendableRecordBatchStream> {
+    ) -> BResult<SendableRecordBatchStream> {
         let serialized_action: protobuf::Action = action.to_owned().try_into()?;

         let mut buf: Vec<u8> = Vec::with_capacity(serialized_action.encoded_len());
@@ -197,8 +220,85 @@ impl BallistaClient {
         }
         unreachable!("Did not receive schema batch from flight server");
     }
+
+    /// Executes the specified action and retrieves the results from the remote executor
+    /// using an optimized block-based transfer operation. This method establishes a
+    /// [BlockDataStream] to facilitate efficient transmission of data blocks, reducing
+    /// computational overhead and improving performance compared to flight protocols.
+    pub async fn execute_do_action(
+        &mut self,
+        action: &Action,
+    ) -> BResult<SendableRecordBatchStream> {
+        let serialized_action: protobuf::Action = action.to_owned().try_into()?;
+
+        let mut buf: Vec<u8> = Vec::with_capacity(serialized_action.encoded_len());
+
+        serialized_action
+            .encode(&mut buf)
+            .map_err(|e| BallistaError::GrpcActionError(format!("{e:?}")))?;
+
+        for i in 0..IO_RETRIES_TIMES {
+            if i > 0 {
+                warn!(
+                    "Remote shuffle read fail, retry {i} times, sleep 
{IO_RETRY_WAIT_TIME_MS} ms."
+                );
+                tokio::time::sleep(std::time::Duration::from_millis(
+                    IO_RETRY_WAIT_TIME_MS,
+                ))
+                .await;
+            }
+
+            let request = tonic::Request::new(arrow_flight::Action {
+                body: buf.clone().into(),
+                r#type: "IO_BLOCK_TRANSPORT".to_string(),
+            });
+            let result = self.flight_client.do_action(request).await;
+            let res = match result {
+                Ok(res) => res,
+                Err(ref err) => {
+                    // IO related errors like connection timeout or reset are wrapped with Code::Unknown.
+                    // This means IO related errors will be retried.
+                    if i == IO_RETRIES_TIMES - 1 || err.code() != Code::Unknown {
+                        return BallistaError::GrpcActionError(format!(
+                            "{:?}",
+                            result.unwrap_err()
+                        ))
+                        .into();
+                    }
+                    // retry request
+                    continue;
+                }
+            };
+
+            let stream = res.into_inner();
+            let stream = stream.map(|m| {
+                m.map(|b| b.body).map_err(|e| {
+                    DataFusionError::ArrowError(
+                        Box::new(ArrowError::IpcError(e.to_string())),
+                        None,
+                    )
+                })
+            });
+
+            return Ok(Box::pin(BlockDataStream::try_new(stream).await?));
+        }
+        unreachable!("Did not receive schema batch from flight server");
+    }
 }
 
+/// [FlightDataStream] facilitates the transfer of shuffle data using the Arrow Flight protocol.
+/// Internally, it invokes the `do_get` method on the Arrow Flight server, which returns a stream
+/// of messages, each representing a record batch.
+///
+/// The Flight server is responsible for decompressing and decoding the shuffle file, and then
+/// transmitting each batch as an individual message. Each message is compressed independently.
+///
+/// This approach increases the computational load on the Flight server due to repeated
+/// decompression and compression operations. Furthermore, compression efficiency is reduced
+/// compared to file-level compression, as it operates on smaller data segments.
+///
+/// For further discussion regarding performance implications, refer to:
+/// https://github.com/apache/datafusion-ballista/issues/1315
 struct FlightDataStream {
     stream: Streaming<FlightData>,
     schema: SchemaRef,
@@ -246,3 +346,275 @@ impl RecordBatchStream for FlightDataStream {
         self.schema.clone()
     }
 }
+/// [BlockDataStream] facilitates the transfer of original shuffle files in a block-by-block manner.
+/// This implementation utilizes a custom `do_action` method on the Arrow Flight server.
+/// The primary distinction from [FlightDataStream] is that it does not decompress or decode
+/// the original partition file on the server side. This approach reduces computational overhead
+/// on the Flight server and enables the transmission of less data, owing to improved file-level compression.
+///
+/// For a detailed discussion of the performance advantages, see:
+/// https://github.com/apache/datafusion-ballista/issues/1315
+pub struct BlockDataStream<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> {
+    decoder: StreamDecoder,
+    state_buffer: Buffer,
+    ipc_stream: S,
+    transmitted: usize,
+    pub schema: SchemaRef,
+}
+
+/// maximum length of message with schema definition
+const MAXIMUM_SCHEMA_BUFFER_SIZE: usize = 8_388_608;
+
+impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> BlockDataStream<S> {
+    pub async fn try_new(
+        mut ipc_stream: S,
+    ) -> std::result::Result<Self, DataFusionError> {
+        let mut state_buffer = Buffer::default();
+
+        loop {
+            if state_buffer.len() > MAXIMUM_SCHEMA_BUFFER_SIZE {
+                return Err(ArrowError::IpcError(format!(
+                    "Schema buffer length exceeded maximum buffer size, 
expected {} actual: {}",
+                    MAXIMUM_SCHEMA_BUFFER_SIZE,
+                    state_buffer.len()
+                )).into());
+            }
+
+            match ipc_stream.next().await {
+                Some(Ok(blob)) => {
+                    state_buffer =
+                        Self::combine_buffers(&state_buffer, &Buffer::from(blob));
+
+                    match try_schema_from_ipc_buffer(state_buffer.as_slice()) {
+                        Ok(schema) => {
+                            return Ok(Self {
+                                decoder: StreamDecoder::new(),
+                                transmitted: state_buffer.len(),
+                                state_buffer,
+                                ipc_stream,
+                                schema: Arc::new(schema),
+                            });
+                        }
+                        Err(ArrowError::ParseError(_)) => {
+                            //
+                            // parse errors are ignored, as the whole message may not have been
+                            // received yet, so the schema cannot be extracted
+                            //
+                        }
+                        Err(e) => return Err(e.into()),
+                    }
+                }
+                Some(Err(e)) => return Err(ArrowError::IpcError(e.to_string()).into()),
+                None => {
+                    return Err(ArrowError::IpcError(
+                        "Premature end of the stream while decoding 
schema".to_owned(),
+                    )
+                    .into());
+                }
+            }
+        }
+    }
+}
+
+impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> BlockDataStream<S> {
+    fn combine_buffers(first: &Buffer, second: &Buffer) -> Buffer {
+        let mut combined = MutableBuffer::new(first.len() + second.len());
+        combined.extend_from_slice(first.as_slice());
+        combined.extend_from_slice(second.as_slice());
+        combined.into()
+    }
+
+    fn decode(&mut self) -> std::result::Result<Option<RecordBatch>, ArrowError> {
+        self.decoder.decode(&mut self.state_buffer)
+    }
+
+    fn extend_bytes(&mut self, blob: prost::bytes::Bytes) {
+        //
+        //TODO: do we want to limit maximum buffer size here as well?
+        //
+        self.transmitted += blob.len();
+        self.state_buffer = Self::combine_buffers(&self.state_buffer, &Buffer::from(blob))
+    }
+}
+
+impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> Stream
+    for BlockDataStream<S>
+{
+    type Item = datafusion::error::Result<RecordBatch>;
+
+    fn poll_next(
+        mut self: std::pin::Pin<&mut Self>,
+        cx: &mut std::task::Context<'_>,
+    ) -> std::task::Poll<Option<Self::Item>> {
+        match self.decode() {
+            //
+            // if there is a batch to be read from state buffer return it
+            //
+            Ok(Some(batch)) => std::task::Poll::Ready(Some(Ok(batch))),
+            //
+            // there is no batch in the state buffer, try to pull new data
+            // from the remote IPC stream, decode it, and try to return the next batch
+            //
+            Ok(None) => match self.ipc_stream.poll_next_unpin(cx) {
+                std::task::Poll::Ready(Some(flight_data_result)) => {
+                    match flight_data_result {
+                        Ok(blob) => {
+                            self.extend_bytes(blob);
+
+                            match self.decode() {
+                                Ok(Some(batch)) => {
+                                    std::task::Poll::Ready(Some(Ok(batch)))
+                                }
+                                Ok(None) => {
+                                    cx.waker().wake_by_ref();
+                                    std::task::Poll::Pending
+                                }
+                                Err(e) => std::task::Poll::Ready(Some(Err(
+                                    ArrowError::IpcError(e.to_string()).into(),
+                                ))),
+                            }
+                        }
+                        Err(e) => std::task::Poll::Ready(Some(Err(
+                            ArrowError::IpcError(e.to_string()).into(),
+                        ))),
+                    }
+                }
+                //
+                // end of IPC stream
+                //
+                std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
+                // it's expected that the underlying stream will register the waker callback
+                std::task::Poll::Pending => std::task::Poll::Pending,
+            },
+            Err(e) => std::task::Poll::Ready(Some(Err(ArrowError::IpcError(
+                e.to_string(),
+            )
+            .into()))),
+        }
+    }
+}
+
+impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> RecordBatchStream
+    for BlockDataStream<S>
+{
+    fn schema(&self) -> SchemaRef {
+        self.schema.clone()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use std::sync::Arc;
+
+    use datafusion::arrow::{
+        array::{DictionaryArray, Int32Array, RecordBatch},
+        datatypes::Int32Type,
+        ipc::writer::StreamWriter,
+    };
+    use futures::{StreamExt, TryStreamExt};
+    use prost::bytes::Bytes;
+
+    use crate::client::BlockDataStream;
+
+    fn generate_batches() -> Vec<RecordBatch> {
+        let batch0 = RecordBatch::try_from_iter([
+            ("a", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _),
+            (
+                "b",
+                Arc::new(Int32Array::from(vec![11, 22, 33, 44, 55])) as _,
+            ),
+            (
+                "c",
+                Arc::new(DictionaryArray::<Int32Type>::from_iter([
+                    "hello", "hello", "world", "some", "other",
+                ])) as _,
+            ),
+        ])
+        .unwrap();
+
+        let batch1 = RecordBatch::try_from_iter([
+            (
+                "a",
+                Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50])) as _,
+            ),
+            (
+                "b",
+                Arc::new(Int32Array::from(vec![110, 220, 330, 440, 550])) as _,
+            ),
+            (
+                "c",
+                Arc::new(DictionaryArray::<Int32Type>::from_iter([
+                    "hello", "some", "world", "some", "other",
+                ])) as _,
+            ),
+        ])
+        .unwrap();
+
+        vec![batch0, batch1]
+    }
+
+    fn generate_ipc_stream(batches: &[RecordBatch]) -> Vec<u8> {
+        let mut result = vec![];
+        let mut writer =
+            StreamWriter::try_new(&mut result, &batches[0].schema()).unwrap();
+        for b in batches {
+            writer.write(b).unwrap();
+        }
+
+        writer.finish().unwrap();
+        result
+    }
+
+    #[tokio::test]
+    async fn should_process_chunked() {
+        let batches = generate_batches();
+        let ipc_blob = generate_ipc_stream(&batches);
+        let stream = futures::stream::iter(ipc_blob)
+            .chunks(2)
+            .map(|b| Ok(Bytes::from(b)));
+
+        let result: datafusion::error::Result<Vec<RecordBatch>> =
+            BlockDataStream::try_new(stream)
+                .await
+                .unwrap()
+                .try_collect()
+                .await;
+
+        assert_eq!(batches, result.unwrap())
+    }
+
+    #[tokio::test]
+    async fn should_process_single_message() {
+        let batches = generate_batches();
+        let blob = generate_ipc_stream(&batches);
+        let stream = futures::stream::iter(vec![Ok(Bytes::from(blob))]);
+
+        let result: datafusion::error::Result<Vec<RecordBatch>> =
+            BlockDataStream::try_new(stream)
+                .await
+                .unwrap()
+                .try_collect()
+                .await;
+
+        assert_eq!(batches, result.unwrap())
+    }
+
+    #[tokio::test]
+    #[should_panic = "Premature end of the stream while decoding schema"]
+    async fn should_process_panic_if_not_correct_stream() {
+        let batches = generate_batches();
+        let ipc_blob = generate_ipc_stream(&batches);
+        let stream = futures::stream::iter(ipc_blob[..5].to_vec())
+            .chunks(2)
+            .map(|b| Ok(Bytes::from(b)));
+
+        let result: datafusion::error::Result<Vec<RecordBatch>> =
+            BlockDataStream::try_new(stream)
+                .await
+                .unwrap()
+                .try_collect()
+                .await;
+
+        assert_eq!(batches, result.unwrap())
+    }
+}
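
For reference, the chunk-at-a-time decoding performed by the new [BlockDataStream] follows the
pattern documented for arrow's StreamDecoder: feed arbitrary byte chunks into the decoder and let
it buffer partial IPC messages internally. A minimal, synchronous sketch of that loop (standalone,
not part of this patch; the batch contents and the 16-byte chunk size are arbitrary):

use std::sync::Arc;

use datafusion::arrow::array::{Int32Array, RecordBatch};
use datafusion::arrow::buffer::Buffer;
use datafusion::arrow::error::ArrowError;
use datafusion::arrow::ipc::reader::StreamDecoder;
use datafusion::arrow::ipc::writer::StreamWriter;

fn main() -> Result<(), ArrowError> {
    // Encode one batch into an Arrow IPC stream (stands in for a shuffle file on disk).
    let batch =
        RecordBatch::try_from_iter([("a", Arc::new(Int32Array::from(vec![1, 2, 3])) as _)])?;
    let mut ipc = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut ipc, &batch.schema())?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    // Decode the same bytes again from fixed-size chunks, as a remote reader would
    // receive them; StreamDecoder buffers partial messages internally.
    let mut decoder = StreamDecoder::new();
    for chunk in ipc.chunks(16) {
        let mut buf = Buffer::from(chunk.to_vec());
        while let Some(decoded) = decoder.decode(&mut buf)? {
            assert_eq!(decoded, batch);
        }
    }
    decoder.finish()?; // errors if the stream ended mid-message
    Ok(())
}

BlockDataStream drives the same decoder, but from an async stream of bytes returned by the
executor's `do_action` endpoint, and additionally extracts the schema up front so it can
implement RecordBatchStream.
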
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 823d1d16..ff7766de 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -36,6 +36,8 @@ pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str =
     "ballista.shuffle.max_concurrent_read_requests";
 pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str =
     "ballista.shuffle.force_remote_read";
+pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
+    "ballista.shuffle.remote_read_prefer_flight";
 
 pub type ParseResult<T> = result::Result<T, String>;
 use std::sync::LazyLock;
@@ -60,6 +62,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
                          "Forces the shuffle reader to always read partitions 
via the Arrow Flight client, even when partitions are local to the 
node.".to_string(),
                          DataType::Boolean,
                          Some((false).to_string())),
+        ConfigEntry::new(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT.to_string(),
+                         "Forces the shuffle reader to use the flight reader instead of the block reader for remote reads. The block reader usually has better performance and resource utilization.".to_string(),
+                         DataType::Boolean,
+                         Some((false).to_string())),
 
     ];
     entries
@@ -191,6 +197,13 @@ impl BallistaConfig {
     pub fn shuffle_reader_force_remote_read(&self) -> bool {
         self.get_bool_setting(BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ)
     }
+    /// Forces the shuffle reader to prefer the flight protocol over the block protocol
+    /// when reading remote shuffle partitions.
+    ///
+    /// The block protocol is usually more performant than the flight protocol.
+    pub fn shuffle_reader_remote_prefer_flight(&self) -> bool {
+        self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT)
+    }
 
     fn get_usize_setting(&self, key: &str) -> usize {
         if let Some(v) = self.settings.get(key) {
diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs
index c352c987..eab3c0d0 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -357,7 +357,7 @@ async fn execute_query(
 
                 info!("Job {job_id} finished executing in {duration:?} ");
                 let streams = partition_location.into_iter().map(move |partition| {
-                    let f = fetch_partition(partition, max_message_size)
+                    let f = fetch_partition(partition, max_message_size, true)
                         .map_err(|e| ArrowError::ExternalError(Box::new(e)));
 
                     futures::stream::once(f).try_flatten()
@@ -372,6 +372,7 @@ async fn execute_query(
 async fn fetch_partition(
     location: PartitionLocation,
     max_message_size: usize,
+    flight_transport: bool,
 ) -> Result<SendableRecordBatchStream> {
     let metadata = location.executor_meta.ok_or_else(|| {
         DataFusionError::Internal("Received empty executor 
metadata".to_owned())
@@ -391,6 +392,7 @@ async fn fetch_partition(
             &location.path,
             host,
             port,
+            flight_transport,
         )
         .await
         .map_err(|e| DataFusionError::External(Box::new(e)))
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index 53624d17..617654ec 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -161,6 +161,7 @@ impl ExecutionPlan for ShuffleReaderExec {
             config.ballista_shuffle_reader_maximum_concurrent_requests();
         let max_message_size = config.ballista_grpc_client_max_message_size();
         let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
+        let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
 
         if force_remote_read {
             debug!(
@@ -193,6 +194,7 @@ impl ExecutionPlan for ShuffleReaderExec {
             max_request_num,
             max_message_size,
             force_remote_read,
+            prefer_flight,
         );
 
         let result = RecordBatchStreamAdapter::new(
@@ -386,6 +388,7 @@ fn send_fetch_partitions(
     max_request_num: usize,
     max_message_size: usize,
     force_remote_read: bool,
+    flight_transport: bool,
 ) -> AbortableReceiverStream {
     let (response_sender, response_receiver) = mpsc::channel(max_request_num);
     let semaphore = Arc::new(Semaphore::new(max_request_num));
@@ -405,7 +408,7 @@ fn send_fetch_partitions(
     spawned_tasks.push(SpawnedTask::spawn(async move {
         for p in local_locations {
             let r = PartitionReaderEnum::Local
-                .fetch_partition(&p, max_message_size)
+                .fetch_partition(&p, max_message_size, flight_transport)
                 .await;
             if let Err(e) = response_sender_c.send(r).await {
                 error!("Fail to send response event to the channel due to 
{e}");
@@ -420,7 +423,7 @@ fn send_fetch_partitions(
             // Block if exceeds max request number.
             let permit = semaphore.acquire_owned().await.unwrap();
             let r = PartitionReaderEnum::FlightRemote
-                .fetch_partition(&p, max_message_size)
+                .fetch_partition(&p, max_message_size, flight_transport)
                 .await;
             // Block if the channel buffer is full.
             if let Err(e) = response_sender.send(r).await {
@@ -446,6 +449,7 @@ trait PartitionReader: Send + Sync + Clone {
         &self,
         location: &PartitionLocation,
         max_message_size: usize,
+        flight_transport: bool,
     ) -> result::Result<SendableRecordBatchStream, BallistaError>;
 }
 
@@ -464,10 +468,11 @@ impl PartitionReader for PartitionReaderEnum {
         &self,
         location: &PartitionLocation,
         max_message_size: usize,
+        flight_transport: bool,
     ) -> result::Result<SendableRecordBatchStream, BallistaError> {
         match self {
             PartitionReaderEnum::FlightRemote => {
-                fetch_partition_remote(location, max_message_size).await
+                fetch_partition_remote(location, max_message_size, flight_transport).await
             }
             PartitionReaderEnum::Local => fetch_partition_local(location).await,
             PartitionReaderEnum::ObjectStoreRemote => {
@@ -480,6 +485,7 @@ impl PartitionReader for PartitionReaderEnum {
 async fn fetch_partition_remote(
     location: &PartitionLocation,
     max_message_size: usize,
+    flight_transport: bool,
 ) -> result::Result<SendableRecordBatchStream, BallistaError> {
     let metadata = &location.executor_meta;
     let partition_id = &location.partition_id;
@@ -501,7 +507,14 @@ async fn fetch_partition_remote(
         })?;
 
     ballista_client
-        .fetch_partition(&metadata.id, partition_id, &location.path, host, port)
+        .fetch_partition(
+            &metadata.id,
+            partition_id,
+            &location.path,
+            host,
+            port,
+            flight_transport,
+        )
         .await
 }
 
@@ -936,6 +949,7 @@ mod tests {
             max_request_num,
             4 * 1024 * 1024,
             false,
+            true,
         );
 
         let stream = RecordBatchStreamAdapter::new(
diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs
index 5c4ba944..49bff2de 100644
--- a/ballista/core/src/extension.rs
+++ b/ballista/core/src/extension.rs
@@ -18,7 +18,7 @@
 use crate::config::{
     BallistaConfig, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME,
-    BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS,
-    BALLISTA_STANDALONE_PARALLELISM,
+    BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_STANDALONE_PARALLELISM,
 };
 use crate::planner::BallistaQueryPlanner;
 use crate::serde::protobuf::KeyValuePair;
@@ -133,6 +133,13 @@ pub trait SessionConfigExt {
         self,
         max_requests: usize,
     ) -> Self;
+
+    fn ballista_shuffle_reader_remote_prefer_flight(&self) -> bool;
+
+    fn with_ballista_shuffle_reader_remote_prefer_flight(
+        self,
+        prefer_flight: bool,
+    ) -> Self;
 }
 
 /// [SessionConfigHelperExt] is set of [SessionConfig] extension methods
@@ -357,6 +364,28 @@ impl SessionConfigExt for SessionConfig {
                 .set_bool(BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, force_remote_read)
         }
     }
+
+    fn ballista_shuffle_reader_remote_prefer_flight(&self) -> bool {
+        self.options()
+            .extensions
+            .get::<BallistaConfig>()
+            .map(|c| c.shuffle_reader_remote_prefer_flight())
+            .unwrap_or_else(|| {
+                BallistaConfig::default().shuffle_reader_remote_prefer_flight()
+            })
+    }
+
+    fn with_ballista_shuffle_reader_remote_prefer_flight(
+        self,
+        prefer_flight: bool,
+    ) -> Self {
+        if self.options().extensions.get::<BallistaConfig>().is_some() {
+            self.set_bool(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, prefer_flight)
+        } else {
+            self.with_option_extension(BallistaConfig::default())
+                .set_bool(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, prefer_flight)
+        }
+    }
 }
 
 impl SessionConfigHelperExt for SessionConfig {
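
As a small usage sketch (crate and module paths are assumed here rather than taken from this
diff), the new option can also be toggled programmatically through the `SessionConfigExt`
method added above; when no BallistaConfig extension has been registered yet, the method
installs a default one first:

use ballista_core::extension::SessionConfigExt;
use datafusion::prelude::SessionConfig;

fn main() {
    // Prefer the Arrow Flight (do_get) path for remote shuffle reads; the
    // default of `false` keeps the new block-based (do_action) transfer.
    let config = SessionConfig::new()
        .with_ballista_shuffle_reader_remote_prefer_flight(true);

    assert!(config.ballista_shuffle_reader_remote_prefer_flight());
}

The same switch is available at runtime via SQL, as exercised by the new test:
SET ballista.shuffle.remote_read_prefer_flight = true.
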
diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml
index fe2f2b1f..de49673e 100644
--- a/ballista/executor/Cargo.toml
+++ b/ballista/executor/Cargo.toml
@@ -52,6 +52,7 @@ parking_lot = { workspace = true }
 tempfile = { workspace = true }
 tokio = { workspace = true, features = ["full"] }
 tokio-stream = { workspace = true, features = ["net"] }
+tokio-util = { version = "0.7", features = ["io-util"] }
 tonic = { workspace = true }
 tracing = { workspace = true, optional = true }
 tracing-appender = { workspace = true, optional = true }
diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs
index cee3dec2..0de791f6 100644
--- a/ballista/executor/src/flight_service.rs
+++ b/ballista/executor/src/flight_service.rs
@@ -21,6 +21,7 @@ use datafusion::arrow::ipc::reader::StreamReader;
 use std::convert::TryFrom;
 use std::fs::File;
 use std::pin::Pin;
+use tokio_util::io::ReaderStream;
 
 use arrow_flight::encode::FlightDataEncoderBuilder;
 use arrow_flight::error::FlightError;
@@ -65,6 +66,9 @@ impl Default for BallistaFlightService {
 type BoxedFlightStream<T> =
     Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
 
+/// shuffle file block transfer size    
+const BLOCK_BUFFER_CAPACITY: usize = 8 * 1024 * 1024;
+
 #[tonic::async_trait]
 impl FlightService for BallistaFlightService {
     type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
@@ -185,16 +189,68 @@ impl FlightService for BallistaFlightService {
     ) -> Result<Response<Self::DoActionStream>, Status> {
         let action = request.into_inner();
 
-        let _action = decode_protobuf(&action.body).map_err(|e| from_ballista_err(&e))?;
-
-        Err(Status::unimplemented("do_action"))
+        match action.r#type.as_str() {
+            // Block transfer sends the Arrow IPC file block by block, without decoding
+            // or decompressing it. This reduces resource utilization, since files are not
+            // decoded or decompressed/recompressed. It usually also transfers less data,
+            // as whole files compress better than individual batches.
+            //
+            // For further discussion regarding performance implications, refer to:
+            // https://github.com/apache/datafusion-ballista/issues/1315
+            "IO_BLOCK_TRANSPORT" => {
+                let action =
+                    decode_protobuf(&action.body).map_err(|e| from_ballista_err(&e))?;
+
+                match &action {
+                    BallistaAction::FetchPartition { path, .. } => {
+                        debug!("FetchPartition reading {path}");
+                        let file = tokio::fs::File::open(&path).await.map_err(|e| {
+                            Status::internal(format!("Failed to open file: {e}"))
+                        })?;
+
+                        debug!(
+                            "streaming file: {} with size: {}",
+                            path,
+                            file.metadata().await?.len()
+                        );
+                        let reader = tokio::io::BufReader::with_capacity(
+                            BLOCK_BUFFER_CAPACITY,
+                            file,
+                        );
+                        let file_stream =
+                            ReaderStream::with_capacity(reader, BLOCK_BUFFER_CAPACITY);
+
+                        let flight_data_stream = file_stream.map(|result| {
+                            result
+                                .map(|bytes| arrow_flight::Result { body: bytes })
+                                .map_err(|e| Status::internal(format!("I/O error: {e}")))
+                        });
+
+                        Ok(Response::new(
+                            Box::pin(flight_data_stream) as Self::DoActionStream
+                        ))
+                    }
+                }
+            }
+            action_type => Err(Status::unimplemented(format!(
+                "do_action does not implement: {}",
+                action_type
+            ))),
+        }
     }
 
     async fn list_actions(
         &self,
         _request: Request<Empty>,
     ) -> Result<Response<Self::ListActionsStream>, Status> {
-        Err(Status::unimplemented("list_actions"))
+        let actions = vec![Ok(ActionType {
+            r#type: "IO_BLOCK_TRANSFER".to_owned(),
+            description: "optimized shuffle data transfer".to_owned(),
+        })];
+
+        Ok(Response::new(
+            Box::pin(futures::stream::iter(actions)) as Self::ListActionsStream
+        ))
     }
 
     async fn do_exchange(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

