This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 3696b9b21 feat: enable scheduler rest api by default (#1506)
3696b9b21 is described below

commit 3696b9b21d0d56feb0a518f51fdedf6de664e65e
Author: Marko Milenković <[email protected]>
AuthorDate: Sun Mar 15 16:19:11 2026 +0000

    feat: enable scheduler rest api by default (#1506)
---
 ballista/scheduler/Cargo.toml                    |   2 +-
 ballista/scheduler/src/api/handlers.rs           | 119 ++++++++++++++++-------
 ballista/scheduler/src/api/mod.rs                |  88 +++++++++--------
 ballista/scheduler/src/api/{mod.rs => routes.rs} |   4 +-
 ballista/scheduler/src/config.rs                 |  16 +++
 ballista/scheduler/src/lib.rs                    |   1 -
 ballista/scheduler/src/scheduler_process.rs      |  48 +++++----
 7 files changed, 182 insertions(+), 96 deletions(-)

diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml
index bed395518..9bf90999e 100644
--- a/ballista/scheduler/Cargo.toml
+++ b/ballista/scheduler/Cargo.toml
@@ -34,7 +34,7 @@ required-features = ["build-binary"]
 
 [features]
 build-binary = ["clap", "tracing-subscriber", "tracing-appender", "tracing", 
"ballista-core/build-binary"]
-default = ["build-binary", "substrait"]
+default = ["build-binary", "substrait", "rest-api"]
 # job info can cache stage plans, in some cases where
 # task plans can be re-computed, cache behavior may need to be disabled.
 disable-stage-plan-cache = []
diff --git a/ballista/scheduler/src/api/handlers.rs 
b/ballista/scheduler/src/api/handlers.rs
index a7ade706f..3cf1f79a6 100644
--- a/ballista/scheduler/src/api/handlers.rs
+++ b/ballista/scheduler/src/api/handlers.rs
@@ -10,10 +10,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use crate::scheduler_server::SchedulerServer;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use crate::state::execution_graph::ExecutionStage;
 use crate::state::execution_graph_dot::ExecutionGraphDot;
+use crate::{api::SchedulerErrorResponse, scheduler_server::SchedulerServer};
 use axum::{
     Json,
     extract::{Path, State},
@@ -38,6 +38,19 @@ use std::time::Duration;
 struct SchedulerStateResponse {
     started: u128,
     version: &'static str,
+    substrait_support: bool,
+    keda_support: bool,
+    prometheus_support: bool,
+    graphviz_support: bool,
+    spark_support: bool,
+    scheduling_policy: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    advertise_flight_sql_endpoint: Option<String>,
+}
+
+#[derive(Debug, serde::Serialize)]
+struct SchedulerVersionResponse {
+    version: &'static str,
 }
 #[derive(Debug, serde::Serialize)]
 pub struct ExecutorMetaResponse {
@@ -89,6 +102,24 @@ pub async fn get_scheduler_state<
     let response = SchedulerStateResponse {
         started: data_server.start_time,
         version: BALLISTA_VERSION,
+        substrait_support: cfg!(feature = "substrait"),
+        keda_support: cfg!(feature = "keda-scaler"),
+        prometheus_support: cfg!(feature = "prometheus-metrics"),
+        graphviz_support: cfg!(feature = "graphviz-support"),
+        spark_support: cfg!(feature = "spark-compat"),
+        scheduling_policy: 
data_server.state.config.scheduling_policy.to_string(),
+        advertise_flight_sql_endpoint: data_server
+            .state
+            .config
+            .advertise_flight_sql_endpoint
+            .clone(),
+    };
+    Json(response)
+}
+
+pub async fn get_scheduler_version() -> impl IntoResponse {
+    let response = SchedulerVersionResponse {
+        version: BALLISTA_VERSION,
     };
     Json(response)
 }
@@ -122,15 +153,13 @@ pub async fn get_jobs<
     U: AsExecutionPlan + Send + Sync + 'static,
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
-) -> Result<impl IntoResponse, StatusCode> {
-    // TODO: Display last seen information in UI
+) -> Result<impl IntoResponse, SchedulerErrorResponse> {
     let state = &data_server.state;
 
-    let jobs = state
-        .task_manager
-        .get_jobs()
-        .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+    let jobs = state.task_manager.get_jobs().await.map_err(|e| {
+        tracing::error!("Error occurred while getting jobs, reason: {e:?}");
+        SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
+    })?;
 
     let jobs: Vec<JobResponse> = jobs
         .iter()
@@ -166,17 +195,17 @@ pub async fn get_job<
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
     Path(job_id): Path<String>,
-) -> Result<impl IntoResponse, StatusCode> {
+) -> Result<impl IntoResponse, SchedulerErrorResponse> {
     let graph = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
         .map_err(|err| {
-            tracing::error!("Error occurred while getting the execution graph 
for job '{job_id}': {err:?}");
-            StatusCode::INTERNAL_SERVER_ERROR
+            tracing::error!("Error occurred while getting the execution graph 
for job '{job_id}' reason: {err:?}");
+            
SchedulerErrorResponse::with_error(StatusCode::INTERNAL_SERVER_ERROR, 
format!("Error occurred while getting the execution graph for job '{job_id}'"))
         })?
-        .ok_or(StatusCode::NOT_FOUND)?;
+        .ok_or_else(|| SchedulerErrorResponse::new(StatusCode::NOT_FOUND))?;
     let stage_plan = format!("{:?}", graph);
     let job = graph.as_ref();
     let (plain_status, job_status) =
@@ -207,7 +236,7 @@ pub async fn cancel_job<
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
     Path(job_id): Path<String>,
-) -> Result<impl IntoResponse, StatusCode> {
+) -> Result<impl IntoResponse, SchedulerErrorResponse> {
     // 404 if the job doesn't exist
     let job_status = data_server
         .state
@@ -216,9 +245,12 @@ pub async fn cancel_job<
         .await
         .map_err(|err| {
             tracing::error!("Error getting job status: {err:?}");
-            StatusCode::INTERNAL_SERVER_ERROR
+            SchedulerErrorResponse::with_error(
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("Error getting job status: {err}"),
+            )
         })?
-        .ok_or(StatusCode::NOT_FOUND)?;
+        .ok_or_else(|| SchedulerErrorResponse::new(StatusCode::NOT_FOUND))?;
 
     match &job_status.status {
         None | Some(Status::Queued(_)) | Some(Status::Running(_)) => {
@@ -229,11 +261,13 @@ pub async fn cancel_job<
                     tracing::error!(
                         "Error getting query stage event loop sender: {err:?}"
                     );
-                    StatusCode::INTERNAL_SERVER_ERROR
+                    
SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
                 })?
                 .post_event(QueryStageSchedulerEvent::JobCancel(job_id))
                 .await
-                .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+                .map_err(|_| {
+                    
SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
+                })?;
 
             Ok((
                 StatusCode::OK,
@@ -274,13 +308,19 @@ pub async fn get_query_stages<
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
     Path(job_id): Path<String>,
-) -> Result<impl IntoResponse, StatusCode> {
+) -> Result<impl IntoResponse, SchedulerErrorResponse> {
     if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+        .map_err(|e| {
+            tracing::error!("Error occurred while getting the query stages for 
job '{job_id}' reason: {e:?}");
+            SchedulerErrorResponse::with_error(
+                StatusCode::INTERNAL_SERVER_ERROR,
+                format!("Error occurred while getting the query stages for job 
'{job_id}'"),
+            )
+        })?
     {
         let stages = graph
             .as_ref()
@@ -332,7 +372,7 @@ pub async fn get_query_stages<
 
         Ok(Json(QueryStagesResponse { stages }))
     } else {
-        Ok(Json(QueryStagesResponse { stages: vec![] }))
+        Err(SchedulerErrorResponse::new(StatusCode::NOT_FOUND))
     }
 }
 
@@ -409,18 +449,24 @@ pub async fn get_job_dot_graph<
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
     Path(job_id): Path<String>,
-) -> Result<String, StatusCode> {
+) -> Result<String, SchedulerErrorResponse> {
     if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+        .map_err(|e| {
+            tracing::error!("Error occurred while getting the dot graph for 
job '{job_id}' reason: {e:?}");
+            SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
+        })?
     {
         ExecutionGraphDot::generate(graph.as_ref())
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
+            .map_err(|e|  {
+                tracing::error!("Error occurred while getting the dot graph 
for job '{job_id}' reason: {e:?}");
+                SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
+            })
     } else {
-        Ok("Not Found".to_string())
+        Err(SchedulerErrorResponse::new(StatusCode::NOT_FOUND))
     }
 }
 
@@ -430,18 +476,18 @@ pub async fn get_query_stage_dot_graph<
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
     Path((job_id, stage_id)): Path<(String, usize)>,
-) -> Result<impl IntoResponse, StatusCode> {
+) -> Result<impl IntoResponse, SchedulerErrorResponse> {
     if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
+        .map_err(|_| 
SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR))?
     {
         ExecutionGraphDot::generate_for_query_stage(graph.as_ref(), stage_id)
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
+            .map_err(|_| 
SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR))
     } else {
-        Ok("Not Found".to_string())
+        Err(SchedulerErrorResponse::new(StatusCode::NOT_FOUND))
     }
 }
 #[cfg(feature = "graphviz-support")]
@@ -451,8 +497,8 @@ pub async fn get_job_svg_graph<
 >(
     State(data_server): State<Arc<SchedulerServer<T, U>>>,
     Path(job_id): Path<String>,
-) -> Result<impl IntoResponse, StatusCode> {
-    let dot = get_job_dot_graph(State(data_server.clone()), 
Path(job_id)).await?;
+) -> Result<impl IntoResponse, SchedulerErrorResponse> {
+    let dot = get_job_dot_graph(State(data_server.clone()), 
Path(job_id.clone())).await?;
     match graphviz_rust::parse(&dot) {
         Ok(graph) => {
             let result = exec(
@@ -460,7 +506,10 @@ pub async fn get_job_svg_graph<
                 &mut PrinterContext::default(),
                 vec![CommandArg::Format(Format::Svg)],
             )
-            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
+            .map_err(|e| {
+                tracing::error!("Error occurred while getting job svg graph 
for job '{job_id}' reason: {e:?}");
+                SchedulerErrorResponse::new(StatusCode::INTERNAL_SERVER_ERROR)
+            })?;
 
             let svg = String::from_utf8_lossy(&result).to_string();
             Ok(Response::builder()
@@ -468,10 +517,10 @@ pub async fn get_job_svg_graph<
                 .body(svg)
                 .unwrap())
         }
-        Err(_) => Ok(Response::builder()
-            .status(StatusCode::BAD_REQUEST)
-            .body("Cannot parse graph".to_string())
-            .unwrap()),
+        Err(e) => Err(SchedulerErrorResponse::with_error(
+            StatusCode::BAD_REQUEST,
+            e.to_string(),
+        )),
     }
 }
 
diff --git a/ballista/scheduler/src/api/mod.rs 
b/ballista/scheduler/src/api/mod.rs
index 2662e3eea..c02b47fd4 100644
--- a/ballista/scheduler/src/api/mod.rs
+++ b/ballista/scheduler/src/api/mod.rs
@@ -10,48 +10,58 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#[cfg(feature = "rest-api")]
 mod handlers;
+#[cfg(feature = "rest-api")]
+mod routes;
+#[cfg(feature = "rest-api")]
+pub use routes::get_routes;
 
-use crate::scheduler_server::SchedulerServer;
-use axum::{Router, routing::get};
-use datafusion_proto::logical_plan::AsLogicalPlan;
-use datafusion_proto::physical_plan::AsExecutionPlan;
-use std::sync::Arc;
+use axum::response::{IntoResponse, Response};
+use axum::{Json, Router};
+use http::StatusCode;
 
-/// All routes configured for rest-api.
-pub fn get_routes<
-    T: AsLogicalPlan + Clone + Send + Sync + 'static,
-    U: AsExecutionPlan + Send + Sync + 'static,
->(
-    scheduler_server: Arc<SchedulerServer<T, U>>,
-) -> Router {
-    let router = Router::new()
-        .route("/api/state", get(handlers::get_scheduler_state::<T, U>))
-        .route("/api/executors", get(handlers::get_executors::<T, U>))
-        .route("/api/jobs", get(handlers::get_jobs::<T, U>))
-        .route(
-            "/api/job/{job_id}",
-            get(handlers::get_job::<T, U>).patch(handlers::cancel_job::<T, U>),
-        )
-        .route(
-            "/api/job/{job_id}/stages",
-            get(handlers::get_query_stages::<T, U>),
-        )
-        .route(
-            "/api/job/{job_id}/dot",
-            get(handlers::get_job_dot_graph::<T, U>),
-        )
-        .route(
-            "/api/job/{job_id}/stage/{stage_id}/dot",
-            get(handlers::get_query_stage_dot_graph::<T, U>),
-        )
-        .route("/api/metrics", get(handlers::get_scheduler_metrics::<T, U>));
+pub(super) fn route_disabled(reason: String) -> Router {
+    Router::new().route(
+        "/api/{*path}",
+        axum::routing::any(|| async {
+            SchedulerErrorResponse::with_error(StatusCode::NOT_FOUND, reason)
+        }),
+    )
+}
+
+#[derive(Debug, serde::Serialize)]
+pub(crate) struct SchedulerErrorResponse {
+    #[serde(skip)]
+    status_code: StatusCode,
+    http_code: u16,
+    reason: Option<&'static str>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    error: Option<String>,
+}
 
-    #[cfg(feature = "graphviz-support")]
-    let router = router.route(
-        "/api/job/{job_id}/dot_svg",
-        get(handlers::get_job_svg_graph::<T, U>),
-    );
+impl SchedulerErrorResponse {
+    pub(crate) fn new(status_code: StatusCode) -> Self {
+        Self {
+            status_code,
+            reason: status_code.canonical_reason(),
+            http_code: status_code.as_u16(),
+            error: None,
+        }
+    }
+    pub(crate) fn with_error(status_code: StatusCode, error: String) -> Self {
+        Self {
+            status_code,
+            reason: status_code.canonical_reason(),
+            http_code: status_code.as_u16(),
+            error: Some(error),
+        }
+    }
+}
 
-    router.with_state(scheduler_server)
+impl IntoResponse for SchedulerErrorResponse {
+    fn into_response(self) -> Response {
+        let status = self.status_code;
+        (status, Json(self)).into_response()
+    }
 }
diff --git a/ballista/scheduler/src/api/mod.rs 
b/ballista/scheduler/src/api/routes.rs
similarity index 95%
copy from ballista/scheduler/src/api/mod.rs
copy to ballista/scheduler/src/api/routes.rs
index 2662e3eea..9e7996518 100644
--- a/ballista/scheduler/src/api/mod.rs
+++ b/ballista/scheduler/src/api/routes.rs
@@ -10,8 +10,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-mod handlers;
-
+use crate::api::handlers;
 use crate::scheduler_server::SchedulerServer;
 use axum::{Router, routing::get};
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -27,6 +26,7 @@ pub fn get_routes<
 ) -> Router {
     let router = Router::new()
         .route("/api/state", get(handlers::get_scheduler_state::<T, U>))
+        .route("/api/version", get(handlers::get_scheduler_version))
         .route("/api/executors", get(handlers::get_executors::<T, U>))
         .route("/api/jobs", get(handlers::get_jobs::<T, U>))
         .route(
diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index 9ddede024..dd73ee9a1 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -190,6 +190,15 @@ pub struct Config {
         help = "The interval to check expired or dead executors"
     )]
     pub expire_dead_executor_interval_seconds: u64,
+
+    #[cfg(feature = "rest-api")]
+    /// Should the rest api be disabled
+    #[arg(
+        long,
+        default_value_t = false,
+        help = "Should the REST API be disabled"
+    )]
+    pub disable_rest_api: bool,
 }
 
 /// Configurations for the ballista scheduler of scheduling jobs and tasks
@@ -245,6 +254,9 @@ pub struct SchedulerConfig {
     pub override_create_grpc_client_endpoint: Option<EndpointOverrideFn>,
     /// Whether to use TLS when connecting to executors (for flight proxy)
     pub use_tls: bool,
+    #[cfg(feature = "rest-api")]
+    /// Should the rest api be disabled
+    pub disable_rest_api: bool,
 }
 
 impl Default for SchedulerConfig {
@@ -273,6 +285,8 @@ impl Default for SchedulerConfig {
             override_physical_codec: None,
             override_create_grpc_client_endpoint: None,
             use_tls: false,
+            #[cfg(feature = "rest-api")]
+            disable_rest_api: false,
         }
     }
 }
@@ -520,6 +534,8 @@ impl TryFrom<Config> for SchedulerConfig {
             override_session_builder: None,
             override_create_grpc_client_endpoint: None,
             use_tls: false,
+            #[cfg(feature = "rest-api")]
+            disable_rest_api: opt.disable_rest_api,
         };
 
         Ok(config)
diff --git a/ballista/scheduler/src/lib.rs b/ballista/scheduler/src/lib.rs
index 401df59e3..533711ecb 100644
--- a/ballista/scheduler/src/lib.rs
+++ b/ballista/scheduler/src/lib.rs
@@ -17,7 +17,6 @@
 
 #![doc = include_str ! ("../README.md")]
 #![warn(missing_docs)]
-#[cfg(feature = "rest-api")]
 /// REST API endpoints for scheduler operations.
 pub mod api;
 /// Cluster management and executor coordination.
diff --git a/ballista/scheduler/src/scheduler_process.rs 
b/ballista/scheduler/src/scheduler_process.rs
index 3e6a97b2f..92f8f1c8e 100644
--- a/ballista/scheduler/src/scheduler_process.rs
+++ b/ballista/scheduler/src/scheduler_process.rs
@@ -15,8 +15,18 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::api::SchedulerErrorResponse;
 use crate::flight_proxy_service::BallistaFlightProxyService;
 
+#[cfg(feature = "rest-api")]
+use crate::api::get_routes;
+use crate::api::route_disabled;
+use crate::cluster::BallistaCluster;
+use crate::config::SchedulerConfig;
+use crate::metrics::default_metrics_collector;
+use crate::scheduler_server::SchedulerServer;
+#[cfg(feature = "keda-scaler")]
+use 
crate::scheduler_server::externalscaler::external_scaler_server::ExternalScalerServer;
 use arrow_flight::flight_service_server::FlightServiceServer;
 use ballista_core::BALLISTA_VERSION;
 use ballista_core::error::BallistaError;
@@ -32,17 +42,6 @@ use http::StatusCode;
 use log::info;
 use std::{net::SocketAddr, sync::Arc};
 use tonic::service::RoutesBuilder;
-
-#[cfg(feature = "rest-api")]
-use crate::api::get_routes;
-use crate::cluster::BallistaCluster;
-use crate::config::SchedulerConfig;
-
-use crate::metrics::default_metrics_collector;
-use crate::scheduler_server::SchedulerServer;
-#[cfg(feature = "keda-scaler")]
-use 
crate::scheduler_server::externalscaler::external_scaler_server::ExternalScalerServer;
-
 /// Creates as initialized scheduler service
 /// without exposing it as a grpc service
 pub async fn create_scheduler<
@@ -130,17 +129,30 @@ pub async fn start_grpc_service<
     tonic_builder.add_service(ExternalScalerServer::new(scheduler.clone()));
 
     let tonic = tonic_builder.routes().into_axum_router();
-    let tonic = tonic.fallback(|| async { (StatusCode::NOT_FOUND, "404 - Not 
Found") });
+
+    // registering default handler for unmatched requests
+    let tonic =
+        tonic.fallback(|| async { 
SchedulerErrorResponse::new(StatusCode::NOT_FOUND) });
 
     #[cfg(feature = "rest-api")]
-    let axum = get_routes(Arc::new(scheduler));
-    #[cfg(feature = "rest-api")]
-    let final_route = axum
-        .merge(tonic)
-        .into_make_service_with_connect_info::<SocketAddr>();
+    let final_route = if config.disable_rest_api {
+        tonic
+            .merge(route_disabled(
+                "REST API has been disabled at startup".to_string(),
+            ))
+            .into_make_service_with_connect_info::<SocketAddr>()
+    } else {
+        let axum = get_routes(Arc::new(scheduler));
+        axum.merge(tonic)
+            .into_make_service_with_connect_info::<SocketAddr>()
+    };
 
     #[cfg(not(feature = "rest-api"))]
-    let final_route = 
tonic.into_make_service_with_connect_info::<SocketAddr>();
+    let final_route = tonic
+        .merge(route_disabled(
+            "REST API has been disabled at compile time".to_string(),
+        ))
+        .into_make_service_with_connect_info::<SocketAddr>();
 
     let listener = tokio::net::TcpListener::bind(&address)
         .await


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to