This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push: new f305a76e feat: Improve Remote Shuffle Read Speed and Resource Utilisation (#1318) f305a76e is described below commit f305a76e1e0da0dbf80c6bd6d9f033e1845b3a5e Author: Marko Milenković <milenkov...@users.noreply.github.com> AuthorDate: Sat Sep 13 19:03:11 2025 +0100 feat: Improve Remote Shuffle Read Speed and Resource Utilisation (#1318) * add block stream implementation * integrate do_action * add flight server action * add config option * fix action name * fix clippy issues * add config test * remove drop * improve messages * minor comment update * minor comment * fix clippy --- Cargo.lock | 1 + ballista/client/tests/context_checks.rs | 63 ++++ ballista/core/src/client.rs | 390 ++++++++++++++++++++- ballista/core/src/config.rs | 13 + .../core/src/execution_plans/distributed_query.rs | 4 +- .../core/src/execution_plans/shuffle_reader.rs | 22 +- ballista/core/src/extension.rs | 31 +- ballista/executor/Cargo.toml | 1 + ballista/executor/src/flight_service.rs | 64 +++- 9 files changed, 570 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 58d3ae03..4e721e80 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1036,6 +1036,7 @@ dependencies = [ "tempfile", "tokio", "tokio-stream", + "tokio-util", "tonic", "tracing", "tracing-appender", diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs index a4392f77..638add1c 100644 --- a/ballista/client/tests/context_checks.rs +++ b/ballista/client/tests/context_checks.rs @@ -872,4 +872,67 @@ mod supported { Ok(()) } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_force_local_read_with_flight( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) -> datafusion::error::Result<()> { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + ctx.sql("SET ballista.shuffle.force_remote_read = true") + .await? + .show() + .await?; + + ctx.sql("SET ballista.shuffle.remote_read_prefer_flight = true") + .await? + .show() + .await?; + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.shuffle.remote_read_prefer_flight' order by name limit 1") + .await? + .collect() + .await?; + + let expected = [ + "+--------------------------------------------+-------+", + "| name | value |", + "+--------------------------------------------+-------+", + "| ballista.shuffle.remote_read_prefer_flight | true |", + "+--------------------------------------------+-------+", + ]; + + assert_batches_eq!(expected, &result); + + let expected = [ + "+------------+----------+", + "| string_col | count(*) |", + "+------------+----------+", + "| 30 | 1 |", + "| 31 | 2 |", + "+------------+----------+", + ]; + + let result = ctx + .sql("select string_col, count(*) from test where id > 4 group by string_col order by string_col") + .await? 
+ .collect() + .await?; + + assert_batches_eq!(expected, &result); + + Ok(()) + } } diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs index c1da64ac..528570f9 100644 --- a/ballista/core/src/client.rs +++ b/ballista/core/src/client.rs @@ -25,7 +25,7 @@ use std::{ task::{Context, Poll}, }; -use crate::error::{BallistaError, Result}; +use crate::error::{BallistaError, Result as BResult}; use crate::serde::scheduler::{Action, PartitionId}; use arrow_flight; @@ -33,12 +33,16 @@ use arrow_flight::utils::flight_data_to_arrow_batch; use arrow_flight::Ticket; use arrow_flight::{flight_service_client::FlightServiceClient, FlightData}; use datafusion::arrow::array::ArrayRef; +use datafusion::arrow::buffer::{Buffer, MutableBuffer}; +use datafusion::arrow::ipc::convert::try_schema_from_ipc_buffer; +use datafusion::arrow::ipc::reader::StreamDecoder; use datafusion::arrow::{ datatypes::{Schema, SchemaRef}, error::ArrowError, record_batch::RecordBatch, }; use datafusion::error::DataFusionError; +use datafusion::error::Result; use crate::serde::protobuf; use crate::utils::create_grpc_client_connection; @@ -61,7 +65,11 @@ const IO_RETRY_WAIT_TIME_MS: u64 = 3000; impl BallistaClient { /// Create a new BallistaClient to connect to the executor listening on the specified /// host and port - pub async fn try_new(host: &str, port: u16, max_message_size: usize) -> Result<Self> { + pub async fn try_new( + host: &str, + port: u16, + max_message_size: usize, + ) -> BResult<Self> { let addr = format!("http://{host}:{port}"); debug!("BallistaClient connecting to {addr}"); let connection = @@ -81,7 +89,11 @@ impl BallistaClient { Ok(Self { flight_client }) } - /// Fetch a partition from an executor + /// Retrieves a partition from an executor. + /// + /// Depending on the value of the `flight_transport` parameter, this method will utilize either + /// the Arrow Flight protocol for compatibility, or a more efficient block-based transfer mechanism. + /// The block-based transfer is optimized for performance and reduces computational overhead on the server. pub async fn fetch_partition( &mut self, executor_id: &str, @@ -89,7 +101,8 @@ impl BallistaClient { path: &str, host: &str, port: u16, - ) -> Result<SendableRecordBatchStream> { + flight_transport: bool, + ) -> BResult<SendableRecordBatchStream> { let action = Action::FetchPartition { job_id: partition_id.job_id.clone(), stage_id: partition_id.stage_id, @@ -98,8 +111,14 @@ impl BallistaClient { host: host.to_owned(), port, }; - self.execute_action(&action) - .await + + let result = if flight_transport { + self.execute_do_get(&action).await + } else { + self.execute_do_action(&action).await + }; + + result .map_err(|error| match error { // map grpc connection error to partition fetch error. BallistaError::GrpcActionError(msg) => { @@ -122,11 +141,15 @@ impl BallistaClient { }) } - /// Execute an action and retrieve the results - pub async fn execute_action( + /// Executes the specified action and retrieves the results from the remote executor. + /// + /// This method establishes a [FlightDataStream] to facilitate the transfer of data + /// using the Arrow Flight protocol. The [FlightDataStream] handles the streaming + /// of record batches from the server to the client in an efficient and structured manner. 
+ pub async fn execute_do_get( &mut self, action: &Action, - ) -> Result<SendableRecordBatchStream> { + ) -> BResult<SendableRecordBatchStream> { let serialized_action: protobuf::Action = action.to_owned().try_into()?; let mut buf: Vec<u8> = Vec::with_capacity(serialized_action.encoded_len()); @@ -197,8 +220,85 @@ impl BallistaClient { } unreachable!("Did not receive schema batch from flight server"); } + + /// Executes the specified action and retrieves the results from the remote executor + /// using an optimized block-based transfer operation. This method establishes a + /// [BlockDataStream] to facilitate efficient transmission of data blocks, reducing + /// computational overhead and improving performance compared to flight protocols. + pub async fn execute_do_action( + &mut self, + action: &Action, + ) -> BResult<SendableRecordBatchStream> { + let serialized_action: protobuf::Action = action.to_owned().try_into()?; + + let mut buf: Vec<u8> = Vec::with_capacity(serialized_action.encoded_len()); + + serialized_action + .encode(&mut buf) + .map_err(|e| BallistaError::GrpcActionError(format!("{e:?}")))?; + + for i in 0..IO_RETRIES_TIMES { + if i > 0 { + warn!( + "Remote shuffle read fail, retry {i} times, sleep {IO_RETRY_WAIT_TIME_MS} ms." + ); + tokio::time::sleep(std::time::Duration::from_millis( + IO_RETRY_WAIT_TIME_MS, + )) + .await; + } + + let request = tonic::Request::new(arrow_flight::Action { + body: buf.clone().into(), + r#type: "IO_BLOCK_TRANSPORT".to_string(), + }); + let result = self.flight_client.do_action(request).await; + let res = match result { + Ok(res) => res, + Err(ref err) => { + // IO related error like connection timeout, reset... will warp with Code::Unknown + // This means IO related error will retry. + if i == IO_RETRIES_TIMES - 1 || err.code() != Code::Unknown { + return BallistaError::GrpcActionError(format!( + "{:?}", + result.unwrap_err() + )) + .into(); + } + // retry request + continue; + } + }; + + let stream = res.into_inner(); + let stream = stream.map(|m| { + m.map(|b| b.body).map_err(|e| { + DataFusionError::ArrowError( + Box::new(ArrowError::IpcError(e.to_string())), + None, + ) + }) + }); + + return Ok(Box::pin(BlockDataStream::try_new(stream).await?)); + } + unreachable!("Did not receive schema batch from flight server"); + } } +/// [FlightDataStream] facilitates the transfer of shuffle data using the Arrow Flight protocol. +/// Internally, it invokes the `do_get` method on the Arrow Flight server, which returns a stream +/// of messages, each representing a record batch. +/// +/// The Flight server is responsible for decompressing and decoding the shuffle file, and then +/// transmitting each batch as an individual message. Each message is compressed independently. +/// +/// This approach increases the computational load on the Flight server due to repeated +/// decompression and compression operations. Furthermore, compression efficiency is reduced +/// compared to file-level compression, as it operates on smaller data segments. +/// +/// For further discussion regarding performance implications, refer to: +/// https://github.com/apache/datafusion-ballista/issues/1315 struct FlightDataStream { stream: Streaming<FlightData>, schema: SchemaRef, @@ -246,3 +346,275 @@ impl RecordBatchStream for FlightDataStream { self.schema.clone() } } +/// [BlockDataStream] facilitates the transfer of original shuffle files in a block-by-block manner. +/// This implementation utilizes a custom `do_action` method on the Arrow Flight server. 
+/// The primary distinction from [FlightDataStream] is that it does not decompress or decode +/// the original partition file on the server side. This approach reduces computational overhead +/// on the Flight server and enables the transmission of less data, owing to improved file-level compression. +/// +/// For a detailed discussion of the performance advantages, see: +/// https://github.com/apache/datafusion-ballista/issues/1315 +pub struct BlockDataStream<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> { + decoder: StreamDecoder, + state_buffer: Buffer, + ipc_stream: S, + transmitted: usize, + pub schema: SchemaRef, +} + +/// maximum length of message with schema definition +const MAXIMUM_SCHEMA_BUFFER_SIZE: usize = 8_388_608; + +impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> BlockDataStream<S> { + pub async fn try_new( + mut ipc_stream: S, + ) -> std::result::Result<Self, DataFusionError> { + let mut state_buffer = Buffer::default(); + + loop { + if state_buffer.len() > MAXIMUM_SCHEMA_BUFFER_SIZE { + return Err(ArrowError::IpcError(format!( + "Schema buffer length exceeded maximum buffer size, expected {} actual: {}", + MAXIMUM_SCHEMA_BUFFER_SIZE, + state_buffer.len() + )).into()); + } + + match ipc_stream.next().await { + Some(Ok(blob)) => { + state_buffer = + Self::combine_buffers(&state_buffer, &Buffer::from(blob)); + + match try_schema_from_ipc_buffer(state_buffer.as_slice()) { + Ok(schema) => { + return Ok(Self { + decoder: StreamDecoder::new(), + transmitted: state_buffer.len(), + state_buffer, + ipc_stream, + schema: Arc::new(schema), + }); + } + Err(ArrowError::ParseError(_)) => { + // + // parse errors are ignored as may have not received whole message + // thus schema may not be extracted + // + } + Err(e) => return Err(e.into()), + } + } + Some(Err(e)) => return Err(ArrowError::IpcError(e.to_string()).into()), + None => { + return Err(ArrowError::IpcError( + "Premature end of the stream while decoding schema".to_owned(), + ) + .into()); + } + } + } + } +} + +impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> BlockDataStream<S> { + fn combine_buffers(first: &Buffer, second: &Buffer) -> Buffer { + let mut combined = MutableBuffer::new(first.len() + second.len()); + combined.extend_from_slice(first.as_slice()); + combined.extend_from_slice(second.as_slice()); + combined.into() + } + + fn decode(&mut self) -> std::result::Result<Option<RecordBatch>, ArrowError> { + self.decoder.decode(&mut self.state_buffer) + } + + fn extend_bytes(&mut self, blob: prost::bytes::Bytes) { + // + //TODO: do we want to limit maximum buffer size here as well? 
+ // + self.transmitted += blob.len(); + self.state_buffer = Self::combine_buffers(&self.state_buffer, &Buffer::from(blob)) + } +} + +impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> Stream + for BlockDataStream<S> +{ + type Item = datafusion::error::Result<RecordBatch>; + + fn poll_next( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll<Option<Self::Item>> { + match self.decode() { + // + // if there is a batch to be read from state buffer return it + // + Ok(Some(batch)) => std::task::Poll::Ready(Some(Ok(batch))), + // + // there is no batch in the state buffer, try to pull new data + // from remote ipc decode it try to return next batch + // + Ok(None) => match self.ipc_stream.poll_next_unpin(cx) { + std::task::Poll::Ready(Some(flight_data_result)) => { + match flight_data_result { + Ok(blob) => { + self.extend_bytes(blob); + + match self.decode() { + Ok(Some(batch)) => { + std::task::Poll::Ready(Some(Ok(batch))) + } + Ok(None) => { + cx.waker().wake_by_ref(); + std::task::Poll::Pending + } + Err(e) => std::task::Poll::Ready(Some(Err( + ArrowError::IpcError(e.to_string()).into(), + ))), + } + } + Err(e) => std::task::Poll::Ready(Some(Err( + ArrowError::IpcError(e.to_string()).into(), + ))), + } + } + // + // end of IPC stream + // + std::task::Poll::Ready(None) => std::task::Poll::Ready(None), + // its expected that underlying stream will register waker callback + std::task::Poll::Pending => std::task::Poll::Pending, + }, + Err(e) => std::task::Poll::Ready(Some(Err(ArrowError::IpcError( + e.to_string(), + ) + .into()))), + } + } +} + +impl<S: Stream<Item = Result<prost::bytes::Bytes>> + Unpin> RecordBatchStream + for BlockDataStream<S> +{ + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use datafusion::arrow::{ + array::{DictionaryArray, Int32Array, RecordBatch}, + datatypes::Int32Type, + ipc::writer::StreamWriter, + }; + use futures::{StreamExt, TryStreamExt}; + use prost::bytes::Bytes; + + use crate::client::BlockDataStream; + + fn generate_batches() -> Vec<RecordBatch> { + let batch0 = RecordBatch::try_from_iter([ + ("a", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _), + ( + "b", + Arc::new(Int32Array::from(vec![11, 22, 33, 44, 55])) as _, + ), + ( + "c", + Arc::new(DictionaryArray::<Int32Type>::from_iter([ + "hello", "hello", "world", "some", "other", + ])) as _, + ), + ]) + .unwrap(); + + let batch1 = RecordBatch::try_from_iter([ + ( + "a", + Arc::new(Int32Array::from(vec![10, 20, 30, 40, 50])) as _, + ), + ( + "b", + Arc::new(Int32Array::from(vec![110, 220, 330, 440, 550])) as _, + ), + ( + "c", + Arc::new(DictionaryArray::<Int32Type>::from_iter([ + "hello", "some", "world", "some", "other", + ])) as _, + ), + ]) + .unwrap(); + + vec![batch0, batch1] + } + + fn generate_ipc_stream(batches: &[RecordBatch]) -> Vec<u8> { + let mut result = vec![]; + let mut writer = + StreamWriter::try_new(&mut result, &batches[0].schema()).unwrap(); + for b in batches { + writer.write(b).unwrap(); + } + + writer.finish().unwrap(); + result + } + + #[tokio::test] + async fn should_process_chunked() { + let batches = generate_batches(); + let ipc_blob = generate_ipc_stream(&batches); + let stream = futures::stream::iter(ipc_blob) + .chunks(2) + .map(|b| Ok(Bytes::from(b))); + + let result: datafusion::error::Result<Vec<RecordBatch>> = + BlockDataStream::try_new(stream) + .await + .unwrap() + .try_collect() + .await; + + assert_eq!(batches, result.unwrap()) + } + + 
#[tokio::test] + async fn should_process_single_message() { + let batches = generate_batches(); + let blob = generate_ipc_stream(&batches); + let stream = futures::stream::iter(vec![Ok(Bytes::from(blob))]); + + let result: datafusion::error::Result<Vec<RecordBatch>> = + BlockDataStream::try_new(stream) + .await + .unwrap() + .try_collect() + .await; + + assert_eq!(batches, result.unwrap()) + } + + #[tokio::test] + #[should_panic = "Premature end of the stream while decoding schema"] + async fn should_process_panic_if_not_correct_stream() { + let batches = generate_batches(); + let ipc_blob = generate_ipc_stream(&batches); + let stream = futures::stream::iter(ipc_blob[..5].to_vec()) + .chunks(2) + .map(|b| Ok(Bytes::from(b))); + + let result: datafusion::error::Result<Vec<RecordBatch>> = + BlockDataStream::try_new(stream) + .await + .unwrap() + .try_collect() + .await; + + assert_eq!(batches, result.unwrap()) + } +} diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index 823d1d16..ff7766de 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -36,6 +36,8 @@ pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str = "ballista.shuffle.max_concurrent_read_requests"; pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str = "ballista.shuffle.force_remote_read"; +pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str = + "ballista.shuffle.remote_read_prefer_flight"; pub type ParseResult<T> = result::Result<T, String>; use std::sync::LazyLock; @@ -60,6 +62,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(|| "Forces the shuffle reader to always read partitions via the Arrow Flight client, even when partitions are local to the node.".to_string(), DataType::Boolean, Some((false).to_string())), + ConfigEntry::new(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT.to_string(), + "Forces the shuffle reader to use flight reader instead of block reader for remote read. Block reader usually has better performance and resource utilization".to_string(), + DataType::Boolean, + Some((false).to_string())), ]; entries @@ -191,6 +197,13 @@ impl BallistaConfig { pub fn shuffle_reader_force_remote_read(&self) -> bool { self.get_bool_setting(BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ) } + /// Forces the shuffle reader to prefer flight protocol over block protocol + /// to read remote shuffle partition. 
+ /// + /// Block protocol is usually more performant than flight protocol + pub fn shuffle_reader_remote_prefer_flight(&self) -> bool { + self.get_bool_setting(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT) + } fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs index c352c987..eab3c0d0 100644 --- a/ballista/core/src/execution_plans/distributed_query.rs +++ b/ballista/core/src/execution_plans/distributed_query.rs @@ -357,7 +357,7 @@ async fn execute_query( info!("Job {job_id} finished executing in {duration:?} "); let streams = partition_location.into_iter().map(move |partition| { - let f = fetch_partition(partition, max_message_size) + let f = fetch_partition(partition, max_message_size, true) .map_err(|e| ArrowError::ExternalError(Box::new(e))); futures::stream::once(f).try_flatten() @@ -372,6 +372,7 @@ async fn execute_query( async fn fetch_partition( location: PartitionLocation, max_message_size: usize, + flight_transport: bool, ) -> Result<SendableRecordBatchStream> { let metadata = location.executor_meta.ok_or_else(|| { DataFusionError::Internal("Received empty executor metadata".to_owned()) @@ -391,6 +392,7 @@ async fn fetch_partition( &location.path, host, port, + flight_transport, ) .await .map_err(|e| DataFusionError::External(Box::new(e))) diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 53624d17..617654ec 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -161,6 +161,7 @@ impl ExecutionPlan for ShuffleReaderExec { config.ballista_shuffle_reader_maximum_concurrent_requests(); let max_message_size = config.ballista_grpc_client_max_message_size(); let force_remote_read = config.ballista_shuffle_reader_force_remote_read(); + let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight(); if force_remote_read { debug!( @@ -193,6 +194,7 @@ impl ExecutionPlan for ShuffleReaderExec { max_request_num, max_message_size, force_remote_read, + prefer_flight, ); let result = RecordBatchStreamAdapter::new( @@ -386,6 +388,7 @@ fn send_fetch_partitions( max_request_num: usize, max_message_size: usize, force_remote_read: bool, + flight_transport: bool, ) -> AbortableReceiverStream { let (response_sender, response_receiver) = mpsc::channel(max_request_num); let semaphore = Arc::new(Semaphore::new(max_request_num)); @@ -405,7 +408,7 @@ fn send_fetch_partitions( spawned_tasks.push(SpawnedTask::spawn(async move { for p in local_locations { let r = PartitionReaderEnum::Local - .fetch_partition(&p, max_message_size) + .fetch_partition(&p, max_message_size, flight_transport) .await; if let Err(e) = response_sender_c.send(r).await { error!("Fail to send response event to the channel due to {e}"); @@ -420,7 +423,7 @@ fn send_fetch_partitions( // Block if exceeds max request number. let permit = semaphore.acquire_owned().await.unwrap(); let r = PartitionReaderEnum::FlightRemote - .fetch_partition(&p, max_message_size) + .fetch_partition(&p, max_message_size, flight_transport) .await; // Block if the channel buffer is full. 
if let Err(e) = response_sender.send(r).await { @@ -446,6 +449,7 @@ trait PartitionReader: Send + Sync + Clone { &self, location: &PartitionLocation, max_message_size: usize, + flight_transport: bool, ) -> result::Result<SendableRecordBatchStream, BallistaError>; } @@ -464,10 +468,11 @@ impl PartitionReader for PartitionReaderEnum { &self, location: &PartitionLocation, max_message_size: usize, + flight_transport: bool, ) -> result::Result<SendableRecordBatchStream, BallistaError> { match self { PartitionReaderEnum::FlightRemote => { - fetch_partition_remote(location, max_message_size).await + fetch_partition_remote(location, max_message_size, flight_transport).await } PartitionReaderEnum::Local => fetch_partition_local(location).await, PartitionReaderEnum::ObjectStoreRemote => { @@ -480,6 +485,7 @@ impl PartitionReader for PartitionReaderEnum { async fn fetch_partition_remote( location: &PartitionLocation, max_message_size: usize, + flight_transport: bool, ) -> result::Result<SendableRecordBatchStream, BallistaError> { let metadata = &location.executor_meta; let partition_id = &location.partition_id; @@ -501,7 +507,14 @@ async fn fetch_partition_remote( })?; ballista_client - .fetch_partition(&metadata.id, partition_id, &location.path, host, port) + .fetch_partition( + &metadata.id, + partition_id, + &location.path, + host, + port, + flight_transport, + ) .await } @@ -936,6 +949,7 @@ mod tests { max_request_num, 4 * 1024 * 1024, false, + true, ); let stream = RecordBatchStreamAdapter::new( diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs index 5c4ba944..49bff2de 100644 --- a/ballista/core/src/extension.rs +++ b/ballista/core/src/extension.rs @@ -18,7 +18,7 @@ use crate::config::{ BallistaConfig, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS, - BALLISTA_STANDALONE_PARALLELISM, + BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_STANDALONE_PARALLELISM, }; use crate::planner::BallistaQueryPlanner; use crate::serde::protobuf::KeyValuePair; @@ -133,6 +133,13 @@ pub trait SessionConfigExt { self, max_requests: usize, ) -> Self; + + fn ballista_shuffle_reader_remote_prefer_flight(&self) -> bool; + + fn with_ballista_shuffle_reader_remote_prefer_flight( + self, + prefer_flight: bool, + ) -> Self; } /// [SessionConfigHelperExt] is set of [SessionConfig] extension methods @@ -357,6 +364,28 @@ impl SessionConfigExt for SessionConfig { .set_bool(BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, force_remote_read) } } + + fn ballista_shuffle_reader_remote_prefer_flight(&self) -> bool { + self.options() + .extensions + .get::<BallistaConfig>() + .map(|c| c.shuffle_reader_remote_prefer_flight()) + .unwrap_or_else(|| { + BallistaConfig::default().shuffle_reader_remote_prefer_flight() + }) + } + + fn with_ballista_shuffle_reader_remote_prefer_flight( + self, + prefer_flight: bool, + ) -> Self { + if self.options().extensions.get::<BallistaConfig>().is_some() { + self.set_bool(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, prefer_flight) + } else { + self.with_option_extension(BallistaConfig::default()) + .set_bool(BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, prefer_flight) + } + } } impl SessionConfigHelperExt for SessionConfig { diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index fe2f2b1f..de49673e 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -52,6 +52,7 @@ parking_lot = { workspace = true } tempfile = { workspace = true } tokio 
= { workspace = true, features = ["full"] } tokio-stream = { workspace = true, features = ["net"] } +tokio-util = { version = "0.7", features = ["io-util"] } tonic = { workspace = true } tracing = { workspace = true, optional = true } tracing-appender = { workspace = true, optional = true } diff --git a/ballista/executor/src/flight_service.rs b/ballista/executor/src/flight_service.rs index cee3dec2..0de791f6 100644 --- a/ballista/executor/src/flight_service.rs +++ b/ballista/executor/src/flight_service.rs @@ -21,6 +21,7 @@ use datafusion::arrow::ipc::reader::StreamReader; use std::convert::TryFrom; use std::fs::File; use std::pin::Pin; +use tokio_util::io::ReaderStream; use arrow_flight::encode::FlightDataEncoderBuilder; use arrow_flight::error::FlightError; @@ -65,6 +66,9 @@ impl Default for BallistaFlightService { type BoxedFlightStream<T> = Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>; +/// shuffle file block transfer size +const BLOCK_BUFFER_CAPACITY: usize = 8 * 1024 * 1024; + #[tonic::async_trait] impl FlightService for BallistaFlightService { type DoActionStream = BoxedFlightStream<arrow_flight::Result>; @@ -185,16 +189,68 @@ impl FlightService for BallistaFlightService { ) -> Result<Response<Self::DoActionStream>, Status> { let action = request.into_inner(); - let _action = decode_protobuf(&action.body).map_err(|e| from_ballista_err(&e))?; - - Err(Status::unimplemented("do_action")) + match action.r#type.as_str() { + // Block transfer will transfer arrow ipc file block by block + // without decoding or decompressing, this will provide less resource utilization + // as file are not decoded nor decompressed/compressed. Usually this would transfer less data across + // as files are better compressed due to its size. + // + // For further discussion regarding performance implications, refer to: + // https://github.com/apache/datafusion-ballista/issues/1315 + "IO_BLOCK_TRANSPORT" => { + let action = + decode_protobuf(&action.body).map_err(|e| from_ballista_err(&e))?; + + match &action { + BallistaAction::FetchPartition { path, .. } => { + debug!("FetchPartition reading {path}"); + let file = tokio::fs::File::open(&path).await.map_err(|e| { + Status::internal(format!("Failed to open file: {e}")) + })?; + + debug!( + "streaming file: {} with size: {}", + path, + file.metadata().await?.len() + ); + let reader = tokio::io::BufReader::with_capacity( + BLOCK_BUFFER_CAPACITY, + file, + ); + let file_stream = + ReaderStream::with_capacity(reader, BLOCK_BUFFER_CAPACITY); + + let flight_data_stream = file_stream.map(|result| { + result + .map(|bytes| arrow_flight::Result { body: bytes }) + .map_err(|e| Status::internal(format!("I/O error: {e}"))) + }); + + Ok(Response::new( + Box::pin(flight_data_stream) as Self::DoActionStream + )) + } + } + } + action_type => Err(Status::unimplemented(format!( + "do_action does not implement: {}", + action_type + ))), + } } async fn list_actions( &self, _request: Request<Empty>, ) -> Result<Response<Self::ListActionsStream>, Status> { - Err(Status::unimplemented("list_actions")) + let actions = vec![Ok(ActionType { + r#type: "IO_BLOCK_TRANSFER".to_owned(), + description: "optimized shuffle data transfer".to_owned(), + })]; + + Ok(Response::new( + Box::pin(futures::stream::iter(actions)) as Self::ListActionsStream + )) } async fn do_exchange( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org For additional commands, e-mail: commits-h...@datafusion.apache.org
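
For reference, a minimal sketch of how a client session might switch remote shuffle reads back to the Arrow Flight reader, mirroring the `SET` statement exercised in the new `should_force_local_read_with_flight` test above; the scheduler URL and the `SessionContext::remote` constructor from the Ballista client prelude are assumptions here, not part of this commit:

    use ballista::prelude::*;
    use datafusion::prelude::*;

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        // Assumed: a Ballista scheduler is reachable at this placeholder URL.
        let ctx = SessionContext::remote("df://localhost:50050").await?;

        // The new option defaults to false, so remote shuffle partitions are
        // fetched with the block-based do_action transfer added by this commit.
        // Setting it to true makes the shuffle reader prefer the Arrow Flight
        // do_get path instead.
        ctx.sql("SET ballista.shuffle.remote_read_prefer_flight = true")
            .await?
            .show()
            .await?;

        // Subsequent distributed queries pick up the setting when fetching
        // remote shuffle partitions.
        ctx.sql("SELECT 1").await?.show().await?;

        Ok(())
    }

The same toggle is also available programmatically through the `with_ballista_shuffle_reader_remote_prefer_flight` method this commit adds to the `SessionConfigExt` trait.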