This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new bd45acbf0 feat: (remote) shuffle reader cleanup (#1503)
bd45acbf0 is described below

commit bd45acbf090764b2324d272bca320b3fc3d160bd
Author: Marko Milenković <[email protected]>
AuthorDate: Fri Mar 13 11:54:42 2026 +0000

    feat: (remote) shuffle reader cleanup (#1503)
    
    * shuffle (remote) reader cleanup
    
    * fix review comments
    
    * minor
    
    * skip index check if sort shuffle is disabled
---
 ballista/core/src/client.rs                        |  61 +++--
 ballista/core/src/config.rs                        |  40 +--
 .../core/src/execution_plans/distributed_query.rs  |   8 +-
 .../core/src/execution_plans/shuffle_reader.rs     | 285 ++++++++-------------
 ballista/core/src/extension.rs                     |  38 ++-
 ballista/core/src/utils.rs                         |  28 +-
 examples/examples/standalone-substrait.rs          |   2 -
 7 files changed, 224 insertions(+), 238 deletions(-)

diff --git a/ballista/core/src/client.rs b/ballista/core/src/client.rs
index c01d431ca..fc771fd14 100644
--- a/ballista/core/src/client.rs
+++ b/ballista/core/src/client.rs
@@ -17,17 +17,11 @@
 
 //! Client API for sending requests to executors.
 
-use std::collections::HashMap;
-use std::sync::Arc;
-
-use std::{
-    convert::{TryFrom, TryInto},
-    task::{Context, Poll},
-};
-
 use crate::error::{BallistaError, Result as BResult};
+use crate::extension::BallistaConfigGrpcEndpoint;
+use crate::serde::protobuf;
 use crate::serde::scheduler::{Action, PartitionId};
-
+use crate::utils::create_grpc_client_endpoint;
 use arrow_flight;
 use arrow_flight::Ticket;
 use arrow_flight::utils::flight_data_to_arrow_batch;
@@ -43,21 +37,23 @@ use datafusion::arrow::{
 };
 use datafusion::error::DataFusionError;
 use datafusion::error::Result;
-
-use crate::extension::BallistaConfigGrpcEndpoint;
-use crate::serde::protobuf;
-
-use crate::utils::create_grpc_client_endpoint;
-
 use datafusion::physical_plan::{RecordBatchStream, SendableRecordBatchStream};
 use futures::{Stream, StreamExt};
 use log::{debug, warn};
 use prost::Message;
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::{
+    convert::{TryFrom, TryInto},
+    task::{Context, Poll},
+};
 use tonic::{Code, Streaming};
 
 /// Client for interacting with Ballista executors.
 #[derive(Clone)]
 pub struct BallistaClient {
+    host: String,
+    port: u16,
     flight_client: FlightServiceClient<tonic::transport::channel::Channel>,
 }
 
@@ -109,7 +105,11 @@ impl BallistaClient {
 
         debug!("BallistaClient connected OK: {flight_client:?}");
 
-        Ok(Self { flight_client })
+        Ok(Self {
+            flight_client,
+            host: host.to_string(),
+            port,
+        })
     }
 
     /// Retrieves a partition from an executor.
@@ -117,13 +117,42 @@ impl BallistaClient {
     /// Depending on the value of the `flight_transport` parameter, this method will utilize either
     /// the Arrow Flight protocol for compatibility, or a more efficient block-based transfer mechanism.
     /// The block-based transfer is optimized for performance and reduces computational overhead on the server.
+    ///
+    /// This method is to be used for direct connection to the executor holding the required shuffle partition.
     pub async fn fetch_partition(
         &mut self,
         executor_id: &str,
         partition_id: &PartitionId,
         path: &str,
+        flight_transport: bool,
+    ) -> BResult<SendableRecordBatchStream> {
+        let host = self.host.to_owned();
+        let port = self.port;
+        self.fetch_partition_proxied(
+            executor_id,
+            partition_id,
+            &host,
+            port,
+            path,
+            flight_transport,
+        )
+        .await
+    }
+
+    /// Retrieves a partition from an executor.
+    ///
+    /// Depending on the value of the `flight_transport` parameter, this method will utilize either
+    /// the Arrow Flight protocol for compatibility, or a more efficient block-based transfer mechanism.
+    /// The block-based transfer is optimized for performance and reduces computational overhead on the server.
+    ///
+    /// This method should be used if the request may be proxied.
+    pub async fn fetch_partition_proxied(
+        &mut self,
+        executor_id: &str,
+        partition_id: &PartitionId,
         host: &str,
         port: u16,
+        path: &str,
         flight_transport: bool,
     ) -> BResult<SendableRecordBatchStream> {
         let action = Action::FetchPartition {
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index ca151558c..2fde0a2b2 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -18,23 +18,19 @@
 
 //! Ballista configuration
 
-use std::result;
-use std::{collections::HashMap, fmt::Display};
-
 use crate::error::{BallistaError, Result};
-
 use datafusion::{
     arrow::datatypes::DataType, common::config_err, config::ConfigExtension,
 };
+use std::result;
+use std::{collections::HashMap, fmt::Display};
 
 /// Configuration key for setting the job name displayed in the web UI.
 pub const BALLISTA_JOB_NAME: &str = "ballista.job.name";
 /// Configuration key for standalone processing parallelism.
 pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.parallelism";
-
 /// Configuration key for disabling default cache extension node.
 pub const BALLISTA_CACHE_NOOP: &str = "ballista.cache.noop";
-
 /// Configuration key for maximum concurrent shuffle read requests.
 pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str =
     "ballista.shuffle.max_concurrent_read_requests";
@@ -44,7 +40,6 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str =
 /// Configuration key to prefer Flight protocol for remote shuffle reads.
 pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
     "ballista.shuffle.remote_read_prefer_flight";
-
 /// max message size for gRPC clients
 pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
     "ballista.grpc_client_max_message_size";
@@ -82,6 +77,8 @@ pub const BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE: &str =
     "ballista.shuffle.sort_based.batch_size";
 /// Should client employ pull or push job tracking strategy
 pub const BALLISTA_CLIENT_PULL: &str = "ballista.client.pull";
+/// Should client use tls connection
+pub const BALLISTA_CLIENT_USE_TLS: &str = "ballista.client.use_tls";
 
 /// Result type for configuration parsing operations.
 pub type ParseResult<T> = result::Result<T, String>;
@@ -162,6 +159,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
         ConfigEntry::new(BALLISTA_CLIENT_PULL.to_string(),
                          "Should client employ pull or push job tracking. In pull mode client will make a request to server in the loop, until job finishes. Pull mode is kept for legacy clients.".to_string(),
                          DataType::Boolean,
+                         Some(false.to_string())),
+        ConfigEntry::new(BALLISTA_CLIENT_USE_TLS.to_string(),
+                         "Should connection between client, scheduler, and executors use TLS.".to_string(),
+                         DataType::Boolean,
                          Some(false.to_string()))
     ];
     entries
@@ -274,11 +275,6 @@ impl BallistaConfig {
         &self.settings
     }
 
-    /// Returns the maximum message size for gRPC clients in bytes.
-    pub fn default_grpc_client_max_message_size(&self) -> usize {
-        self.get_usize_setting(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE)
-    }
-
     /// Returns the standalone processing parallelism level.
     pub fn default_standalone_parallelism(&self) -> usize {
         self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM)
@@ -290,25 +286,30 @@ impl BallistaConfig {
     }
 
     /// Returns the gRPC client connection timeout in seconds.
-    pub fn default_grpc_client_connect_timeout_seconds(&self) -> usize {
+    pub fn grpc_client_connect_timeout_seconds(&self) -> usize {
         self.get_usize_setting(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS)
     }
 
     /// Returns the gRPC client request timeout in seconds.
-    pub fn default_grpc_client_timeout_seconds(&self) -> usize {
+    pub fn grpc_client_timeout_seconds(&self) -> usize {
         self.get_usize_setting(BALLISTA_GRPC_CLIENT_TIMEOUT_SECONDS)
     }
 
     /// Returns the TCP keep-alive interval for gRPC clients in seconds.
-    pub fn default_grpc_client_tcp_keepalive_seconds(&self) -> usize {
+    pub fn grpc_client_tcp_keepalive_seconds(&self) -> usize {
         self.get_usize_setting(BALLISTA_GRPC_CLIENT_TCP_KEEPALIVE_SECONDS)
     }
 
     /// Returns the HTTP/2 keep-alive interval for gRPC clients in seconds.
-    pub fn default_grpc_client_http2_keepalive_interval_seconds(&self) -> usize {
+    pub fn grpc_client_http2_keepalive_interval_seconds(&self) -> usize {
         self.get_usize_setting(BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS)
     }
 
+    /// Returns the maximum message size for gRPC clients in bytes.
+    pub fn grpc_client_max_message_size(&self) -> usize {
+        self.get_usize_setting(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE)
+    }
+
     /// Returns whether the default cache node extension is disabled.
     pub fn cache_noop(&self) -> bool {
         self.get_bool_setting(BALLISTA_CACHE_NOOP)
@@ -373,6 +374,11 @@ impl BallistaConfig {
         self.get_bool_setting(BALLISTA_CLIENT_PULL)
     }
 
+    /// should client use TLS to communicate with ballista cluster
+    pub fn client_use_tls(&self) -> bool {
+        self.get_bool_setting(BALLISTA_CLIENT_USE_TLS)
+    }
+
     fn get_usize_setting(&self, key: &str) -> usize {
         if let Some(v) = self.settings.get(key) {
             // infallible because we validate all configs in the constructor
@@ -539,7 +545,7 @@ mod tests {
     #[test]
     fn default_config() -> Result<()> {
         let config = BallistaConfig::default();
-        assert_eq!(16777216, config.default_grpc_client_max_message_size());
+        assert_eq!(16777216, config.grpc_client_max_message_size());
         Ok(())
     }
 }
diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs
index 5f1ad258d..533ee3541 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -253,7 +253,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
                     self.scheduler_url.clone(),
                     self.session_id.clone(),
                     query,
-                    self.config.default_grpc_client_max_message_size(),
+                    self.config.grpc_client_max_message_size(),
                     GrpcClientConfig::from(&self.config),
                     Arc::new(self.metrics.clone()),
                     partition,
@@ -280,7 +280,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
                 execute_query_push(
                     self.scheduler_url.clone(),
                     query,
-                    self.config.default_grpc_client_max_message_size(),
+                    self.config.grpc_client_max_message_size(),
                     GrpcClientConfig::from(&self.config),
                     Arc::new(self.metrics.clone()),
                     partition,
@@ -717,12 +717,12 @@ async fn fetch_partition(
     .await
     .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
     ballista_client
-        .fetch_partition(
+        .fetch_partition_proxied(
             &metadata.id,
             &partition_id.into(),
-            &location.path,
             host,
             port,
+            &location.path,
             flight_transport,
         )
         .await
diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs
index c3676a93b..bc81b2150 100644
--- a/ballista/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/core/src/execution_plans/shuffle_reader.rs
@@ -15,49 +15,46 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use async_trait::async_trait;
-use datafusion::arrow::ipc::reader::StreamReader;
-use datafusion::common::stats::Precision;
-use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
-use std::any::Any;
-use std::collections::HashMap;
-use std::fmt::Debug;
-use std::fs::File;
-use std::io::BufReader;
-use std::pin::Pin;
-use std::result;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
 use crate::client::BallistaClient;
+use crate::error::BallistaError;
 use crate::execution_plans::sort_shuffle::{
     get_index_path, is_sort_shuffle_output, stream_sort_shuffle_partition,
 };
 use crate::extension::{BallistaConfigGrpcEndpoint, SessionConfigExt};
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
-
+use crate::utils::GrpcClientConfig;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::error::ArrowError;
+use datafusion::arrow::ipc::reader::StreamReader;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::common::runtime::SpawnedTask;
-
+use datafusion::common::stats::Precision;
 use datafusion::error::{DataFusionError, Result};
+use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::coalesce::{LimitedBatchCoalescer, PushBatchStatus};
 use datafusion::physical_plan::metrics::{
     BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet,
 };
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use datafusion::physical_plan::{
     ColumnStatistics, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning,
     PlanProperties, RecordBatchStream, SendableRecordBatchStream, Statistics,
 };
+use datafusion::prelude::SessionConfig;
 use futures::{Stream, StreamExt, TryStreamExt, ready};
-
-use crate::error::BallistaError;
-use datafusion::execution::context::TaskContext;
-use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use itertools::Itertools;
 use log::{debug, error, trace};
 use rand::prelude::SliceRandom;
 use rand::rng;
+use std::any::Any;
+use std::collections::HashMap;
+use std::fmt::Debug;
+use std::fs::File;
+use std::io::BufReader;
+use std::pin::Pin;
+use std::result;
+use std::sync::Arc;
+use std::task::{Context, Poll};
 use tokio::sync::{Semaphore, mpsc};
 use tokio_stream::wrappers::ReceiverStream;
 
@@ -162,17 +159,9 @@ impl ExecutionPlan for ShuffleReaderExec {
         debug!("ShuffleReaderExec::execute({task_id})");
 
         let config = context.session_config();
-
-        let max_request_num =
-            config.ballista_shuffle_reader_maximum_concurrent_requests();
-        let max_message_size = config.ballista_grpc_client_max_message_size();
-        let force_remote_read = config.ballista_shuffle_reader_force_remote_read();
-        let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
         let batch_size = config.batch_size();
-        let customize_endpoint = config.ballista_override_create_grpc_client_endpoint();
-        let use_tls = config.ballista_use_tls();
 
-        if force_remote_read {
+        if config.ballista_shuffle_reader_force_remote_read() {
             debug!(
                 "All shuffle partitions will be read as remote partitions! To disable this behavior set: `{}=false`",
                 crate::config::BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ
@@ -180,7 +169,9 @@ impl ExecutionPlan for ShuffleReaderExec {
         }
 
         log::debug!(
-            "ShuffleReaderExec::execute({task_id}) max_request_num: {max_request_num}, max_message_size: {max_message_size}"
+            "ShuffleReaderExec::execute({task_id}) max_request_num: {}, max_message_size: {}",
+            config.ballista_shuffle_reader_maximum_concurrent_requests(),
+            config.ballista_grpc_client_max_message_size()
         );
         let mut partition_locations = HashMap::new();
         for p in &self.partition[partition] {
@@ -198,15 +189,7 @@ impl ExecutionPlan for ShuffleReaderExec {
             .collect();
         // Shuffle partitions for evenly send fetching partition requests to avoid hot executors within multiple tasks
         partition_locations.shuffle(&mut rng());
-        let response_receiver = send_fetch_partitions(
-            partition_locations,
-            max_request_num,
-            max_message_size,
-            force_remote_read,
-            prefer_flight,
-            customize_endpoint,
-            use_tls,
-        );
+        let response_receiver = send_fetch_partitions(partition_locations, config);
 
         let input_stream = Box::pin(RecordBatchStreamAdapter::new(
             self.schema.clone(),
@@ -405,19 +388,19 @@ fn local_remote_read_split(
 
 fn send_fetch_partitions(
     partition_locations: Vec<PartitionLocation>,
-    max_request_num: usize,
-    max_message_size: usize,
-    force_remote_read: bool,
-    flight_transport: bool,
-    customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
-    use_tls: bool,
+    config: &SessionConfig,
 ) -> AbortableReceiverStream {
+    let max_request_num = config.ballista_shuffle_reader_maximum_concurrent_requests();
+    let sort_shuffle_enabled = config.ballista_sort_shuffle_enabled();
+
     let (response_sender, response_receiver) = mpsc::channel(max_request_num);
     let semaphore = Arc::new(Semaphore::new(max_request_num));
     let mut spawned_tasks: Vec<SpawnedTask<()>> = vec![];
 
-    let (local_locations, remote_locations): (Vec<_>, Vec<_>) =
-        local_remote_read_split(partition_locations, force_remote_read);
+    let (local_locations, remote_locations): (Vec<_>, Vec<_>) = local_remote_read_split(
+        partition_locations,
+        config.ballista_shuffle_reader_force_remote_read(),
+    );
 
     debug!(
         "local shuffle file counts:{}, remote shuffle file count:{}.",
@@ -427,46 +410,53 @@ fn send_fetch_partitions(
 
     // keep local shuffle files reading in serial order for memory control.
     let response_sender_c = response_sender.clone();
-    let customize_endpoint_c = customize_endpoint.clone();
-    spawned_tasks.push(SpawnedTask::spawn(async move {
-        for p in local_locations {
-            let r = PartitionReaderEnum::Local
-                .fetch_partition(
-                    &p,
-                    max_message_size,
-                    flight_transport,
-                    customize_endpoint_c.clone(),
-                    use_tls,
-                )
-                .await;
-            if let Err(e) = response_sender_c.send(r).await {
-                error!("Fail to send response event to the channel due to {e}");
+
+    //
+    // fetching local partitions (read from file)
+    //
+
+    spawned_tasks.push(SpawnedTask::spawn_blocking({
+        move || {
+            for p in local_locations {
+                let r = fetch_partition_local(&p, sort_shuffle_enabled);
+                if let Err(e) = response_sender_c.blocking_send(r) {
+                    error!("Fail to send response event to the channel due to {e}");
+                }
             }
         }
     }));
 
+    //
+    // fetching remote partitions (uses grpc flight protocol)
+    //
+    let grpc_config: Arc<GrpcClientConfig> = Arc::new((&config.ballista_config()).into());
+    let customize_endpoint = config.ballista_override_create_grpc_client_endpoint();
+    let prefer_flight = config.ballista_shuffle_reader_remote_prefer_flight();
+
     for p in remote_locations.into_iter() {
         let semaphore = semaphore.clone();
         let response_sender = response_sender.clone();
-        let customize_endpoint_c = customize_endpoint.clone();
-        spawned_tasks.push(SpawnedTask::spawn(async move {
-            // Block if exceeds max request number.
-            let permit = semaphore.acquire_owned().await.unwrap();
-            let r = PartitionReaderEnum::FlightRemote
-                .fetch_partition(
+
+        spawned_tasks.push(SpawnedTask::spawn({
+            let customize_endpoint = customize_endpoint.clone();
+            let grpc_config = grpc_config.clone();
+            async move {
+                // Block if exceeds max request number.
+                let permit = semaphore.acquire_owned().await.unwrap();
+                let r = fetch_partition_remote(
                     &p,
-                    max_message_size,
-                    flight_transport,
-                    customize_endpoint_c,
-                    use_tls,
+                    grpc_config,
+                    prefer_flight,
+                    customize_endpoint,
                 )
                 .await;
-            // Block if the channel buffer is full.
-            if let Err(e) = response_sender.send(r).await {
-                error!("Fail to send response event to the channel due to {e}");
+                // Block if the channel buffer is full.
+                if let Err(e) = response_sender.send(r).await {
+                    error!("Fail to send response event to the channel due to {e}");
+                }
+                // Increase semaphore by dropping existing permits.
+                drop(permit);
             }
-            // Increase semaphore by dropping existing permits.
-            drop(permit);
         }));
     }
 
@@ -477,112 +467,68 @@ fn check_is_local_location(location: &PartitionLocation) -> bool {
     std::path::Path::new(location.path.as_str()).exists()
 }
 
-/// Partition reader Trait, different partition reader can have
-#[async_trait]
-trait PartitionReader: Send + Sync + Clone {
-    // Read partition data from PartitionLocation
-    async fn fetch_partition(
-        &self,
-        location: &PartitionLocation,
-        max_message_size: usize,
-        flight_transport: bool,
-        customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
-        use_tls: bool,
-    ) -> result::Result<SendableRecordBatchStream, BallistaError>;
-}
-
-#[derive(Clone)]
-enum PartitionReaderEnum {
-    Local,
-    FlightRemote,
-    #[allow(dead_code)]
-    ObjectStoreRemote,
-}
+async fn new_ballista_client(
+    host: &str,
+    port: u16,
+    config: &GrpcClientConfig,
+    customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
+) -> result::Result<BallistaClient, BallistaError> {
+    let max_message_size = config.max_message_size;
+    let use_tls = config.use_tls;
 
-#[async_trait]
-impl PartitionReader for PartitionReaderEnum {
-    // Notice return `BallistaError::FetchFailed` will let scheduler re-schedule the task.
-    async fn fetch_partition(
-        &self,
-        location: &PartitionLocation,
-        max_message_size: usize,
-        flight_transport: bool,
-        customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
-        use_tls: bool,
-    ) -> result::Result<SendableRecordBatchStream, BallistaError> {
-        match self {
-            PartitionReaderEnum::FlightRemote => {
-                fetch_partition_remote(
-                    location,
-                    max_message_size,
-                    flight_transport,
-                    customize_endpoint,
-                    use_tls,
-                )
-                .await
-            }
-            PartitionReaderEnum::Local => fetch_partition_local(location).await,
-            PartitionReaderEnum::ObjectStoreRemote => {
-                fetch_partition_object_store(location).await
-            }
-        }
-    }
+    BallistaClient::try_new(host, port, max_message_size, use_tls, customize_endpoint)
+        .await
 }
 
 async fn fetch_partition_remote(
     location: &PartitionLocation,
-    max_message_size: usize,
-    flight_transport: bool,
+    config: Arc<GrpcClientConfig>,
+    prefer_flight: bool,
     customize_endpoint: Option<Arc<BallistaConfigGrpcEndpoint>>,
-    use_tls: bool,
 ) -> result::Result<SendableRecordBatchStream, BallistaError> {
     let metadata = &location.executor_meta;
     let partition_id = &location.partition_id;
-    // TODO for shuffle client connections, we should avoid creating new connections again and again.
-    // And we should also avoid to keep alive too many connections for long time.
     let host = metadata.host.as_str();
     let port = metadata.port;
-    let mut ballista_client = BallistaClient::try_new(
-        host,
-        port,
-        max_message_size,
-        use_tls,
-        customize_endpoint,
-    )
-    .await
-    .map_err(|error| match error {
-        // map grpc connection error to partition fetch error.
-        BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
-            metadata.id.clone(),
-            partition_id.stage_id,
-            partition_id.partition_id,
-            msg,
-        ),
-        other => other,
-    })?;
+
+    // TODO for shuffle client connections, we should avoid creating new connections again and again.
+    // And we should also avoid to keep alive too many connections for long time.
+    let mut ballista_client =
+        new_ballista_client(host, port, &config, customize_endpoint)
+            .await
+            .map_err(|error| match error {
+                // map grpc connection error to partition fetch error.
+                BallistaError::GrpcConnectionError(msg) => BallistaError::FetchFailed(
+                    metadata.id.clone(),
+                    partition_id.stage_id,
+                    partition_id.partition_id,
+                    msg,
+                ),
+                other => other,
+            })?;
 
     ballista_client
-        .fetch_partition(
-            &metadata.id,
-            partition_id,
-            &location.path,
-            host,
-            port,
-            flight_transport,
-        )
+        .fetch_partition(&metadata.id, partition_id, &location.path, prefer_flight)
         .await
 }
 
-async fn fetch_partition_local(
+fn fetch_partition_local(
     location: &PartitionLocation,
+    sort_shuffle_enabled: bool,
 ) -> result::Result<SendableRecordBatchStream, BallistaError> {
     let path = &location.path;
     let metadata = &location.executor_meta;
     let partition_id = &location.partition_id;
     let data_path = std::path::Path::new(path);
 
+    // TODO: we check if file is there then we open it alter
+    //       replace this check with open, and check for error
+    //
     // Check if this is a sort-based shuffle output (has index file)
-    if is_sort_shuffle_output(data_path) {
+    if sort_shuffle_enabled && is_sort_shuffle_output(data_path) {
+        // note: in some cases sort shuffle is not going to be used
+        //       even its enabled. thus we need to check if there is
+        //       sort shuffle file index
         debug!(
             "Reading sort-based shuffle for partition {} from {:?}",
             partition_id.partition_id, data_path
@@ -622,7 +568,8 @@ fn fetch_partition_local_inner(
     let file = File::open(path).map_err(|e| {
         BallistaError::General(format!("Failed to open partition file at {path}: {e:?}"))
     })?;
-    let file = BufReader::new(file);
+    // TODO: make this configurable
+    let file = BufReader::with_capacity(256 * 1024, file);
     // Safety: setting `skip_validation` requires `unsafe`, user assures data is valid
     let reader = unsafe {
         StreamReader::try_new(file, None)
@@ -637,14 +584,6 @@ fn fetch_partition_local_inner(
     Ok(reader)
 }
 
-async fn fetch_partition_object_store(
-    _location: &PartitionLocation,
-) -> result::Result<SendableRecordBatchStream, BallistaError> {
-    Err(BallistaError::NotImplemented(
-        "Should not use ObjectStorePartitionReader".to_string(),
-    ))
-}
-
 struct CoalescedShuffleReaderStream {
     schema: SchemaRef,
     input: SendableRecordBatchStream,
@@ -1120,16 +1059,10 @@ mod tests {
             partition_num,
             file_path.to_str().unwrap().to_string(),
         );
+        let config = SessionConfig::new_with_ballista()
+            .with_ballista_shuffle_reader_maximum_concurrent_requests(max_request_num);
 
-        let response_receiver = send_fetch_partitions(
-            partition_locations,
-            max_request_num,
-            4 * 1024 * 1024,
-            false,
-            true,
-            None,
-            false,
-        );
+        let response_receiver = send_fetch_partitions(partition_locations, &config);
 
         let stream = RecordBatchStreamAdapter::new(
             Arc::new(schema),
diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs
index 2c777097c..f8292645d 100644
--- a/ballista/core/src/extension.rs
+++ b/ballista/core/src/extension.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::config::{
-    BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME,
+    BALLISTA_CLIENT_USE_TLS, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME,
     BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ, BALLISTA_SHUFFLE_READER_MAX_REQUESTS,
     BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT, BALLISTA_STANDALONE_PARALLELISM,
     BallistaConfig,
@@ -233,6 +233,9 @@ pub trait SessionConfigExt {
 
     /// Get whether to use TLS for executor connections
     fn ballista_use_tls(&self) -> bool;
+
+    /// Is short shuffle used
+    fn ballista_sort_shuffle_enabled(&self) -> bool;
 }
 
 /// [SessionConfigHelperExt] is set of [SessionConfig] extension methods
@@ -392,10 +395,8 @@ impl SessionConfigExt for SessionConfig {
         self.options()
             .extensions
             .get::<BallistaConfig>()
-            .map(|c| c.default_grpc_client_max_message_size())
-            .unwrap_or_else(|| {
-                BallistaConfig::default().default_grpc_client_max_message_size()
-            })
+            .map(|c| c.grpc_client_max_message_size())
+            .unwrap_or_else(|| BallistaConfig::default().grpc_client_max_message_size())
     }
 
     fn with_ballista_job_name(self, job_name: &str) -> Self {
@@ -435,6 +436,14 @@ impl SessionConfigExt for SessionConfig {
             })
     }
 
+    fn ballista_sort_shuffle_enabled(&self) -> bool {
+        self.options()
+            .extensions
+            .get::<BallistaConfig>()
+            .map(|c| c.shuffle_sort_based_enabled())
+            .unwrap_or_else(|| BallistaConfig::default().shuffle_sort_based_enabled())
+    }
+
     fn with_ballista_shuffle_reader_maximum_concurrent_requests(
         self,
         max_requests: usize,
@@ -538,13 +547,20 @@ impl SessionConfigExt for SessionConfig {
     }
 
     fn with_ballista_use_tls(self, use_tls: bool) -> Self {
-        self.with_extension(Arc::new(BallistaUseTls(use_tls)))
+        if self.options().extensions.get::<BallistaConfig>().is_some() {
+            self.set_bool(BALLISTA_CLIENT_USE_TLS, use_tls)
+        } else {
+            self.with_option_extension(BallistaConfig::default())
+                .set_bool(BALLISTA_CLIENT_USE_TLS, use_tls)
+        }
     }
 
     fn ballista_use_tls(&self) -> bool {
-        self.get_extension::<BallistaUseTls>()
-            .map(|ext| ext.0)
-            .unwrap_or(false)
+        self.options()
+            .extensions
+            .get::<BallistaConfig>()
+            .map(|c| c.client_use_tls())
+            .unwrap_or_else(|| BallistaConfig::default().client_use_tls())
     }
 }
 
@@ -746,10 +762,6 @@ impl BallistaConfigGrpcEndpoint {
     }
 }
 
-/// Wrapper for cluster-wide TLS configuration
-#[derive(Clone, Copy)]
-pub struct BallistaUseTls(pub bool);
-
 #[derive(Debug)]
 struct BallistaCacheFactory;
 
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 0d6a3d833..f71c74d72 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -63,19 +63,23 @@ pub struct GrpcClientConfig {
     pub tcp_keepalive_seconds: u64,
     /// HTTP/2 keep-alive ping interval in seconds
     pub http2_keepalive_interval_seconds: u64,
+    /// Should client use tls
+    pub use_tls: bool,
+    /// Returns the maximum message size for gRPC clients in bytes.
+    pub max_message_size: usize,
 }
 
 impl From<&BallistaConfig> for GrpcClientConfig {
     fn from(config: &BallistaConfig) -> Self {
         Self {
-            connect_timeout_seconds: config.default_grpc_client_connect_timeout_seconds()
-                as u64,
-            timeout_seconds: config.default_grpc_client_timeout_seconds() as u64,
-            tcp_keepalive_seconds: config.default_grpc_client_tcp_keepalive_seconds()
-                as u64,
+            connect_timeout_seconds: config.grpc_client_connect_timeout_seconds() as u64,
+            timeout_seconds: config.grpc_client_timeout_seconds() as u64,
+            tcp_keepalive_seconds: config.grpc_client_tcp_keepalive_seconds() as u64,
             http2_keepalive_interval_seconds: config
-                .default_grpc_client_http2_keepalive_interval_seconds()
+                .grpc_client_http2_keepalive_interval_seconds()
                 as u64,
+            use_tls: config.client_use_tls(),
+            max_message_size: config.grpc_client_max_message_size(),
         }
     }
 }
@@ -87,6 +91,8 @@ impl Default for GrpcClientConfig {
             timeout_seconds: 20,
             tcp_keepalive_seconds: 3600,
             http2_keepalive_interval_seconds: 300,
+            use_tls: false,
+            max_message_size: 16 * 1024 * 1024,
         }
     }
 }
@@ -312,19 +318,19 @@ mod tests {
         // Verify the conversion picks up the right values
         assert_eq!(
             grpc_config.connect_timeout_seconds,
-            ballista_config.default_grpc_client_connect_timeout_seconds() as u64
+            ballista_config.grpc_client_connect_timeout_seconds() as u64
         );
         assert_eq!(
             grpc_config.timeout_seconds,
-            ballista_config.default_grpc_client_timeout_seconds() as u64
+            ballista_config.grpc_client_timeout_seconds() as u64
         );
         assert_eq!(
             grpc_config.tcp_keepalive_seconds,
-            ballista_config.default_grpc_client_tcp_keepalive_seconds() as u64
+            ballista_config.grpc_client_tcp_keepalive_seconds() as u64
         );
         assert_eq!(
             grpc_config.http2_keepalive_interval_seconds,
-            ballista_config.default_grpc_client_http2_keepalive_interval_seconds() as u64
+            ballista_config.grpc_client_http2_keepalive_interval_seconds() as u64
         );
     }
 
@@ -335,6 +341,8 @@ mod tests {
             timeout_seconds: 30,
             tcp_keepalive_seconds: 1800,
             http2_keepalive_interval_seconds: 150,
+            use_tls: false,
+            max_message_size: 16 * 1024 * 1024,
         };
         let result = create_grpc_client_endpoint("http://localhost:50051", Some(&config));
         assert!(result.is_ok());
diff --git a/examples/examples/standalone-substrait.rs b/examples/examples/standalone-substrait.rs
index 7e8b2036c..95caa8370 100644
--- a/examples/examples/standalone-substrait.rs
+++ b/examples/examples/standalone-substrait.rs
@@ -416,8 +416,6 @@ impl SubstraitSchedulerClient {
                 &metadata.id,
                 &partition_id.into(),
                 &location.path,
-                host,
-                port,
                 flight_transport,
             )
             .await


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to