This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new ac49c18a0 feat: Add arrow flight proxy to scheduler (#1351)
ac49c18a0 is described below

commit ac49c18a00322537c4be52f41716895b2faa514d
Author: Sebastian Eckweiler <[email protected]>
AuthorDate: Sat Jan 24 09:15:58 2026 +0100

    feat: Add arrow flight proxy to scheduler (#1351)
---
 ballista/core/proto/ballista.proto                 |   4 +
 .../core/src/execution_plans/distributed_query.rs  | 141 +++++++++++++++-
 ballista/core/src/serde/generated/ballista.rs      |  12 ++
 ballista/scheduler/scheduler_config_spec.toml      |   2 +-
 ballista/scheduler/src/config.rs                   |   4 +-
 ballista/scheduler/src/flight_proxy_service.rs     | 188 +++++++++++++++++++++
 ballista/scheduler/src/lib.rs                      |   2 +
 ballista/scheduler/src/scheduler_process.rs        |  21 +++
 ballista/scheduler/src/scheduler_server/grpc.rs    |  18 +-
 9 files changed, 381 insertions(+), 11 deletions(-)

diff --git a/ballista/core/proto/ballista.proto 
b/ballista/core/proto/ballista.proto
index 3f28cd26e..19e02e501 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -658,6 +658,10 @@ message JobStatus {
 
 message GetJobStatusResult {
   JobStatus status = 1;
+  oneof flight_proxy {
+      bool local = 2;
+      string external = 3;
+  }
 }
 
 message FilePartitionMetadata {
diff --git a/ballista/core/src/execution_plans/distributed_query.rs 
b/ballista/core/src/execution_plans/distributed_query.rs
index ad0b16487..e79edad1d 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -17,12 +17,13 @@
 
 use crate::client::BallistaClient;
 use crate::config::BallistaConfig;
-use crate::serde::protobuf::SuccessfulJob;
+use crate::serde::protobuf::get_job_status_result::FlightProxy;
 use crate::serde::protobuf::{
     ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
     PartitionLocation, execute_query_params::Query, execute_query_result, 
job_status,
     scheduler_grpc_client::SchedulerGrpcClient,
 };
+use crate::serde::protobuf::{ExecutorMetadata, SuccessfulJob};
 use crate::utils::{GrpcClientConfig, create_grpc_client_connection};
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::arrow::error::ArrowError;
@@ -49,6 +50,7 @@ use std::fmt::Debug;
 use std::marker::PhantomData;
 use std::sync::Arc;
 use std::time::Duration;
+use url::Url;
 
 /// This operator sends a logical plan to a Ballista scheduler for execution 
and
 /// polls the scheduler until the query is complete and then fetches the 
resulting
@@ -295,7 +297,7 @@ async fn execute_query(
 
     info!("Connecting to Ballista scheduler at {scheduler_url}");
     // TODO reuse the scheduler to avoid connecting to the Ballista scheduler 
again and again
-    let connection = create_grpc_client_connection(scheduler_url, &grpc_config)
+    let connection = create_grpc_client_connection(scheduler_url.clone(), 
&grpc_config)
         .await
         .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
 
@@ -327,7 +329,10 @@ async fn execute_query(
     let mut prev_status: Option<job_status::Status> = None;
 
     loop {
-        let GetJobStatusResult { status } = scheduler
+        let GetJobStatusResult {
+            status,
+            flight_proxy,
+        } = scheduler
             .get_job_status(GetJobStatusParams {
                 job_id: job_id.clone(),
             })
@@ -403,8 +408,14 @@ async fn execute_query(
                 // This could be added in a future enhancement by wrapping the 
stream.
 
                 let streams = partition_location.into_iter().map(move 
|partition| {
-                    let f = fetch_partition(partition, max_message_size, true)
-                        .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+                    let f = fetch_partition(
+                        partition,
+                        max_message_size,
+                        true,
+                        scheduler_url.clone(),
+                        flight_proxy.clone(),
+                    )
+                    .map_err(|e| ArrowError::ExternalError(Box::new(e)));
 
                     futures::stream::once(f).try_flatten()
                 });
@@ -415,22 +426,75 @@ async fn execute_query(
     }
 }
 
+fn get_client_host_port(
+    executor_metadata: &ExecutorMetadata,
+    scheduler_url: &str,
+    flight_proxy: &Option<FlightProxy>,
+) -> Result<(String, u16)> {
+    fn split_host_port(address: &str) -> Result<(String, u16)> {
+        let url: Url = address.parse().map_err(|e| {
+            DataFusionError::Execution(format!(
+                "Cannot parse host:port in {address:?}: {e}"
+            ))
+        })?;
+        let host = url
+            .host_str()
+            .ok_or(DataFusionError::Execution(format!(
+                "No host in {address:?}"
+            )))?
+            .to_string();
+        let port: u16 = url.port().ok_or(DataFusionError::Execution(format!(
+            "No port in {address:?}"
+        )))?;
+        Ok((host, port))
+    }
+
+    match flight_proxy {
+        Some(FlightProxy::External(address)) => {
+            debug!("Fetching results from external flight proxy: {}", address);
+            split_host_port(format!("http://{address}").as_str())
+        }
+        Some(FlightProxy::Local(true)) => {
+            debug!("Fetching results from scheduler: {}", scheduler_url);
+            split_host_port(scheduler_url)
+        }
+        Some(FlightProxy::Local(false)) | None => {
+            debug!(
+                "Fetching results from executor: {}:{}",
+                executor_metadata.host, executor_metadata.port
+            );
+            Ok((
+                executor_metadata.host.clone(),
+                executor_metadata.port as u16,
+            ))
+        }
+    }
+}
+
 async fn fetch_partition(
     location: PartitionLocation,
     max_message_size: usize,
     flight_transport: bool,
+    scheduler_url: String,
+    flight_proxy: Option<FlightProxy>,
 ) -> Result<SendableRecordBatchStream> {
     let metadata = location.executor_meta.ok_or_else(|| {
         DataFusionError::Internal("Received empty executor 
metadata".to_owned())
     })?;
+
     let partition_id = location.partition_id.ok_or_else(|| {
         DataFusionError::Internal("Received empty partition id".to_owned())
     })?;
     let host = metadata.host.as_str();
     let port = metadata.port as u16;
-    let mut ballista_client = BallistaClient::try_new(host, port, 
max_message_size)
-        .await
-        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+    let (client_host, client_port) =
+        get_client_host_port(&metadata, &scheduler_url, &flight_proxy)?;
+
+    let mut ballista_client =
+        BallistaClient::try_new(client_host.as_str(), client_port, 
max_message_size)
+            .await
+            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
     ballista_client
         .fetch_partition(
             &metadata.id,
@@ -443,3 +507,64 @@ async fn fetch_partition(
         .await
         .map_err(|e| DataFusionError::External(Box::new(e)))
 }
+
+#[cfg(test)]
+mod test {
+    use crate::execution_plans::distributed_query::get_client_host_port;
+    use crate::serde::protobuf::ExecutorMetadata;
+    use crate::serde::protobuf::get_job_status_result::FlightProxy;
+
+    #[test]
+    fn test_client_host_port() {
+        let scheduler_host = "scheduler";
+        let scheduler_port: u16 = 5000;
+
+        let scheduler_url = format!("http://{scheduler_host}:{scheduler_port}");
+        let executor = ExecutorMetadata {
+            id: "test".to_string(),
+            host: "executor".to_string(),
+            port: 12345,
+            grpc_port: 1,
+            specification: None,
+        };
+
+        // no flight proxy -> client should fetch results from executor
+        assert_eq!(
+            get_client_host_port(&executor, &scheduler_url, &None).unwrap(),
+            (executor.host.clone(), executor.port as u16)
+        );
+
+        // same, no flight proxy
+        assert_eq!(
+            get_client_host_port(
+                &executor,
+                &scheduler_url,
+                &Some(FlightProxy::Local(false))
+            )
+            .unwrap(),
+            (executor.host.clone(), executor.port as u16)
+        );
+
+        // embedded flight proxy on scheduler
+        assert_eq!(
+            get_client_host_port(
+                &executor,
+                &scheduler_url,
+                &Some(FlightProxy::Local(true))
+            )
+            .unwrap(),
+            (scheduler_host.to_string(), scheduler_port)
+        );
+
+        // external proxy
+        assert_eq!(
+            get_client_host_port(
+                &executor,
+                &scheduler_url,
+                &Some(FlightProxy::External("proxy:1234".to_string()))
+            )
+            .unwrap(),
+            ("proxy".to_string(), 1234_u16)
+        );
+    }
+}
diff --git a/ballista/core/src/serde/generated/ballista.rs 
b/ballista/core/src/serde/generated/ballista.rs
index 910b58f2d..d8b2d5ce0 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -1009,6 +1009,18 @@ pub mod job_status {
 pub struct GetJobStatusResult {
     #[prost(message, optional, tag = "1")]
     pub status: ::core::option::Option<JobStatus>,
+    #[prost(oneof = "get_job_status_result::FlightProxy", tags = "2, 3")]
+    pub flight_proxy: 
::core::option::Option<get_job_status_result::FlightProxy>,
+}
+/// Nested message and enum types in `GetJobStatusResult`.
+pub mod get_job_status_result {
+    #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
+    pub enum FlightProxy {
+        #[prost(bool, tag = "2")]
+        Local(bool),
+        #[prost(string, tag = "3")]
+        External(::prost::alloc::string::String),
+    }
 }
 #[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
 pub struct FilePartitionMetadata {
diff --git a/ballista/scheduler/scheduler_config_spec.toml 
b/ballista/scheduler/scheduler_config_spec.toml
index 20bceb5f2..5659d25d8 100644
--- a/ballista/scheduler/scheduler_config_spec.toml
+++ b/ballista/scheduler/scheduler_config_spec.toml
@@ -27,7 +27,7 @@ doc = "Print version of this executable"
 [[param]]
 name = "advertise_flight_sql_endpoint"
 type = "String"
-doc = "Route for proxying flight results via scheduler. Should be of the form 
'IP:PORT'"
+doc = "Route for proxying flight results via scheduler. Use 'HOST:PORT' to let 
clients fetch results from the specified address. If empty a flight proxy will 
be started on the scheduler host and port."
 
 [[param]]
 abbr = "n"
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index 7bafc59eb..bbb82f5f5 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -41,7 +41,9 @@ pub struct Config {
     /// Route for proxying flight results via scheduler (IP:PORT format).
     #[arg(
         long,
-        help = "Route for proxying flight results via scheduler. Should be of 
the form 'IP:PORT"
+        num_args = 0..=1,
+        default_missing_value = "",
+        help = "Route for proxying flight results via scheduler. Use 
'HOST:PORT' to let clients fetch results from the specified address. If empty a 
flight proxy will be started on the scheduler host and port."
     )]
     pub advertise_flight_sql_endpoint: Option<String>,
     /// Namespace for the ballista cluster.
diff --git a/ballista/scheduler/src/flight_proxy_service.rs 
b/ballista/scheduler/src/flight_proxy_service.rs
new file mode 100644
index 000000000..bc5c63153
--- /dev/null
+++ b/ballista/scheduler/src/flight_proxy_service.rs
@@ -0,0 +1,188 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_flight::flight_service_client::FlightServiceClient;
+use arrow_flight::flight_service_server::FlightService;
+use arrow_flight::{
+    Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, 
FlightInfo,
+    HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, 
Ticket,
+};
+use ballista_core::error::BallistaError;
+use ballista_core::serde::decode_protobuf;
+use ballista_core::serde::scheduler::Action as BallistaAction;
+use ballista_core::utils::{GrpcClientConfig, create_grpc_client_connection};
+
+use futures::{Stream, TryFutureExt};
+use log::debug;
+use std::pin::Pin;
+use tonic::{Request, Response, Status, Streaming};
+
+/// Service implementing a proxy from scheduler to executor Apache Arrow 
Flight Protocol
+///
+/// The proxy only implements the FlightService::do_get api and forwards the 
requests
+/// to the respective executors.
+///
+#[derive(Clone)]
+pub struct BallistaFlightProxyService {
+    max_decoding_message_size: usize,
+    max_encoding_message_size: usize,
+}
+
+impl BallistaFlightProxyService {
+    pub fn new(
+        max_decoding_message_size: usize,
+        max_encoding_message_size: usize,
+    ) -> Self {
+        Self {
+            max_decoding_message_size,
+            max_encoding_message_size,
+        }
+    }
+}
+
+type BoxedFlightStream<T> =
+    Pin<Box<dyn Stream<Item = Result<T, Status>> + Send + 'static>>;
+
+#[tonic::async_trait]
+impl FlightService for BallistaFlightProxyService {
+    type DoActionStream = BoxedFlightStream<arrow_flight::Result>;
+    type DoExchangeStream = BoxedFlightStream<FlightData>;
+    type DoGetStream = BoxedFlightStream<FlightData>;
+    type DoPutStream = BoxedFlightStream<PutResult>;
+    type HandshakeStream = BoxedFlightStream<HandshakeResponse>;
+    type ListActionsStream = BoxedFlightStream<ActionType>;
+    type ListFlightsStream = BoxedFlightStream<FlightInfo>;
+    async fn handshake(
+        &self,
+        _request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        Err(Status::unimplemented("handshake"))
+    }
+
+    async fn list_flights(
+        &self,
+        _request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        Err(Status::unimplemented("list_flights"))
+    }
+
+    async fn get_flight_info(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        Err(Status::unimplemented("get_flight_info"))
+    }
+
+    async fn poll_flight_info(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<PollInfo>, Status> {
+        Err(Status::unimplemented("poll_flight_info"))
+    }
+
+    async fn get_schema(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        Err(Status::unimplemented("get_schema"))
+    }
+
+    async fn do_get(
+        &self,
+        request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        let ticket = request.into_inner();
+
+        let action =
+            decode_protobuf(&ticket.ticket).map_err(|e| 
from_ballista_err(&e))?;
+
+        match &action {
+            BallistaAction::FetchPartition {
+                host, port, job_id, ..
+            } => {
+                debug!("Fetching results for job id: {job_id} from 
{host}:{port}");
+                let mut client = get_flight_client(
+                    host,
+                    *port,
+                    self.max_decoding_message_size,
+                    self.max_encoding_message_size,
+                )
+                .map_err(|e| from_ballista_err(&e))
+                .await?;
+                client
+                    .do_get(Request::new(ticket))
+                    .await
+                    .map(|r| Response::new(Box::pin(r.into_inner()) as 
Self::DoGetStream))
+            }
+        }
+    }
+
+    async fn do_put(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        Err(Status::unimplemented("do_put"))
+    }
+
+    async fn do_exchange(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        Err(Status::unimplemented("do_exchange"))
+    }
+
+    async fn do_action(
+        &self,
+        _request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        Err(Status::unimplemented("do_action"))
+    }
+
+    async fn list_actions(
+        &self,
+        _request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        Err(Status::unimplemented("list_actions"))
+    }
+}
+
+fn from_ballista_err(e: &ballista_core::error::BallistaError) -> Status {
+    Status::internal(format!("Ballista Error: {e:?}"))
+}
+
+async fn get_flight_client(
+    host: &str,
+    port: u16,
+    max_decoding_message_size: usize,
+    max_encoding_message_size: usize,
+) -> Result<FlightServiceClient<tonic::transport::channel::Channel>, 
BallistaError> {
+    let addr = format!("http://{host}:{port}");
+    let grpc_config = GrpcClientConfig::default();
+    let connection = create_grpc_client_connection(addr.clone(), &grpc_config)
+        .await
+        .map_err(|e| {
+            BallistaError::GrpcConnectionError(format!(
+                "Error connecting to Ballista scheduler or executor at {addr}: 
{e:?}"
+            ))
+        })?;
+    let flight_client = FlightServiceClient::new(connection)
+        .max_decoding_message_size(max_decoding_message_size)
+        .max_encoding_message_size(max_encoding_message_size);
+
+    debug!("FlightProxyService connected: {flight_client:?}");
+    Ok(flight_client)
+}
diff --git a/ballista/scheduler/src/lib.rs b/ballista/scheduler/src/lib.rs
index 6b61c8235..401df59e3 100644
--- a/ballista/scheduler/src/lib.rs
+++ b/ballista/scheduler/src/lib.rs
@@ -41,6 +41,8 @@ pub mod standalone;
 /// Scheduler state management.
 pub mod state;
 
+mod flight_proxy_service;
+
 /// Test utilities for scheduler testing.
 #[cfg(test)]
 pub mod test_utils;
diff --git a/ballista/scheduler/src/scheduler_process.rs 
b/ballista/scheduler/src/scheduler_process.rs
index 361ade55f..679d2c411 100644
--- a/ballista/scheduler/src/scheduler_process.rs
+++ b/ballista/scheduler/src/scheduler_process.rs
@@ -15,6 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::flight_proxy_service::BallistaFlightProxyService;
+
+use arrow_flight::flight_service_server::FlightServiceServer;
 use ballista_core::BALLISTA_VERSION;
 use ballista_core::error::BallistaError;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer;
@@ -96,6 +99,24 @@ pub async fn start_grpc_service<
     let mut tonic_builder = RoutesBuilder::default();
     tonic_builder.add_service(scheduler_grpc_server);
 
+    match &config.advertise_flight_sql_endpoint {
+        Some(proxy) if proxy.is_empty() => {
+            info!("Adding embedded flight proxy service on scheduler");
+            let flight_proxy = 
FlightServiceServer::new(BallistaFlightProxyService::new(
+                config.grpc_server_max_encoding_message_size as usize,
+                config.grpc_server_max_decoding_message_size as usize,
+            ))
+            .max_decoding_message_size(
+                config.grpc_server_max_decoding_message_size as usize,
+            )
+            .max_encoding_message_size(
+                config.grpc_server_max_encoding_message_size as usize,
+            );
+            tonic_builder.add_service(flight_proxy);
+        }
+        _ => {}
+    }
+
     #[cfg(feature = "keda-scaler")]
     tonic_builder.add_service(ExternalScalerServer::new(scheduler.clone()));
 
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs 
b/ballista/scheduler/src/scheduler_server/grpc.rs
index 4d2148aca..f50ec7f1a 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -48,6 +48,7 @@ use std::ops::Deref;
 use crate::cluster::{bind_task_bias, bind_task_round_robin};
 use crate::config::TaskDistributionPolicy;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
+use ballista_core::serde::protobuf::get_job_status_result::FlightProxy;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tonic::{Request, Response, Status};
 
@@ -464,8 +465,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
     ) -> Result<Response<GetJobStatusResult>, Status> {
         let job_id = request.into_inner().job_id;
         trace!("Received get_job_status request for job {}", job_id);
+
+        let flight_proxy =
+            self.state
+                .config
+                .advertise_flight_sql_endpoint
+                .clone()
+                .map(|s| match s {
+                    s if s.is_empty() => FlightProxy::Local(true),
+                    s => FlightProxy::External(s),
+                });
+
         match self.state.task_manager.get_job_status(&job_id).await {
-            Ok(status) => Ok(Response::new(GetJobStatusResult { status })),
+            Ok(status) => Ok(Response::new(GetJobStatusResult {
+                status,
+                flight_proxy,
+            })),
+
             Err(e) => {
                 let msg = format!("Error getting status for job {job_id}: 
{e:?}");
                 error!("{msg}");


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to