This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 5e98229e3 feat: job scheduling with push based job status updates 
(#1478)
5e98229e3 is described below

commit 5e98229e3ad1eb58f395c79e0a0dc9e01df43df0
Author: Marko Milenković <[email protected]>
AuthorDate: Mon Mar 2 17:05:16 2026 +0000

    feat: job scheduling with push based job status updates (#1478)
    
    * implement push based job execution
    
    * minor cleanup
    
    * add additional test
    
    * refactor, extract common code to methods
    
    * fix job name issue
    
    * clone subscriber to avoid awaiting while holding a lock
    
    * addressing few comments
    
    * remove print
    
    * update sender to use try send
    
    * fix clippy
---
 ballista/client/tests/context_checks.rs            |  58 +++++
 ballista/core/proto/ballista.proto                 |   2 +
 ballista/core/src/config.rs                        |  13 +-
 .../core/src/execution_plans/distributed_query.rs  | 240 ++++++++++++++++---
 ballista/core/src/lib.rs                           |   5 +
 ballista/core/src/serde/generated/ballista.rs      |  89 +++++++
 ballista/scheduler/src/cluster/memory.rs           |  64 +++++-
 ballista/scheduler/src/cluster/mod.rs              |   9 +-
 ballista/scheduler/src/cluster/test_util/mod.rs    |   2 +-
 ballista/scheduler/src/scheduler_server/event.rs   |   4 +-
 ballista/scheduler/src/scheduler_server/grpc.rs    | 255 ++++++++++++++-------
 ballista/scheduler/src/scheduler_server/mod.rs     | 223 +++++++++++++++++-
 .../src/scheduler_server/query_stage_scheduler.rs  |  37 ++-
 ballista/scheduler/src/state/mod.rs                |   3 +
 ballista/scheduler/src/state/task_manager.rs       |   8 +-
 ballista/scheduler/src/test_utils.rs               |  17 +-
 16 files changed, 897 insertions(+), 132 deletions(-)

diff --git a/ballista/client/tests/context_checks.rs 
b/ballista/client/tests/context_checks.rs
index 908e9c23d..0faf3b9b7 100644
--- a/ballista/client/tests/context_checks.rs
+++ b/ballista/client/tests/context_checks.rs
@@ -1047,4 +1047,62 @@ mod supported {
         ];
         assert_batches_eq!(expected, &result);
     }
+
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_force_client_pull(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+        test_data: String,
+    ) -> datafusion::error::Result<()> {
+        ctx.register_parquet(
+            "test",
+            &format!("{test_data}/alltypes_plain.parquet"),
+            Default::default(),
+        )
+        .await?;
+
+        ctx.sql("SET ballista.client.pull = true")
+            .await?
+            .show()
+            .await?;
+
+        let result = ctx
+            .sql("select name, value from information_schema.df_settings where 
name like 'ballista.client.pull' order by name limit 1")
+            .await?
+            .collect()
+            .await?;
+
+        let expected = [
+            "+----------------------+-------+",
+            "| name                 | value |",
+            "+----------------------+-------+",
+            "| ballista.client.pull | true  |",
+            "+----------------------+-------+",
+        ];
+
+        assert_batches_eq!(expected, &result);
+
+        let expected = [
+            "+------------+----------+",
+            "| string_col | count(*) |",
+            "+------------+----------+",
+            "| 30         | 1        |",
+            "| 31         | 2        |",
+            "+------------+----------+",
+        ];
+
+        let result = ctx
+            .sql("select string_col, count(*) from test where id > 4 group by 
string_col order by string_col")
+            .await?
+            .collect()
+            .await?;
+
+        assert_batches_eq!(expected, &result);
+
+        Ok(())
+    }
 }
diff --git a/ballista/core/proto/ballista.proto 
b/ballista/core/proto/ballista.proto
index 97812bb0d..641aa9aac 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -759,6 +759,8 @@ service SchedulerGrpc {
 
   rpc RemoveSession (RemoveSessionParams) returns (RemoveSessionResult) {}
 
+  rpc ExecuteQueryPush (ExecuteQueryParams) returns (stream 
GetJobStatusResult) {}
+
   rpc ExecuteQuery (ExecuteQueryParams) returns (ExecuteQueryResult) {}
 
   rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {}
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 15e031a16..ca151558c 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -80,6 +80,8 @@ pub const BALLISTA_SHUFFLE_SORT_BASED_SPILL_THRESHOLD: &str =
 /// Configuration key for sort shuffle target batch size in rows.
 pub const BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE: &str =
     "ballista.shuffle.sort_based.batch_size";
+/// Whether the client should use a pull- or push-based job tracking strategy
+pub const BALLISTA_CLIENT_PULL: &str = "ballista.client.pull";
 
 /// Result type for configuration parsing operations.
 pub type ParseResult<T> = result::Result<T, String>;
@@ -156,7 +158,11 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, 
ConfigEntry>> = LazyLock::new(||
         ConfigEntry::new(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE.to_string(),
                          "Target batch size in rows for coalescing small 
batches in sort shuffle".to_string(),
                          DataType::UInt64,
-                         Some((8192).to_string()))
+                         Some((8192).to_string())),
+        ConfigEntry::new(BALLISTA_CLIENT_PULL.to_string(),
+                         "Should client employ pull or push job tracking. In 
pull mode client will make a request to server in the loop, until job finishes. 
Pull mode is kept for legacy clients.".to_string(),
+                         DataType::Boolean,
+                         Some(false.to_string()))
     ];
     entries
         .into_iter()
@@ -362,6 +368,11 @@ impl BallistaConfig {
         self.get_usize_setting(BALLISTA_SHUFFLE_SORT_BASED_BATCH_SIZE)
     }
 
+    /// Whether the client should use a pull- or push-based job tracking strategy
+    pub fn client_pull(&self) -> bool {
+        self.get_bool_setting(BALLISTA_CLIENT_PULL)
+    }
+
     fn get_usize_setting(&self, key: &str) -> usize {
         if let Some(v) = self.settings.get(key) {
             // infallible because we validate all configs in the constructor
diff --git a/ballista/core/src/execution_plans/distributed_query.rs 
b/ballista/core/src/execution_plans/distributed_query.rs
index 8fb6935c4..5f1ad258d 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -247,33 +247,62 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for 
DistributedQueryExec<T> {
 
         let session_config = context.session_config().clone();
 
-        let stream = futures::stream::once(
-            execute_query(
-                self.scheduler_url.clone(),
-                self.session_id.clone(),
-                query,
-                self.config.default_grpc_client_max_message_size(),
-                GrpcClientConfig::from(&self.config),
-                Arc::new(self.metrics.clone()),
-                partition,
-                session_config,
+        if session_config.ballista_config().client_pull() {
+            let stream = futures::stream::once(
+                execute_query_pull(
+                    self.scheduler_url.clone(),
+                    self.session_id.clone(),
+                    query,
+                    self.config.default_grpc_client_max_message_size(),
+                    GrpcClientConfig::from(&self.config),
+                    Arc::new(self.metrics.clone()),
+                    partition,
+                    session_config,
+                )
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
             )
-            .map_err(|e| ArrowError::ExternalError(Box::new(e))),
-        )
-        .try_flatten()
-        .inspect(move |batch| {
-            metric_total_bytes.add(
-                batch
-                    .as_ref()
-                    .map(|b| b.get_array_memory_size())
-                    .unwrap_or(0),
-            );
-
-            metric_row_count.add(batch.as_ref().map(|b| 
b.num_rows()).unwrap_or(0));
-        });
-
-        let schema = self.schema();
-        Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+            .try_flatten()
+            .inspect(move |batch| {
+                metric_total_bytes.add(
+                    batch
+                        .as_ref()
+                        .map(|b| b.get_array_memory_size())
+                        .unwrap_or(0),
+                );
+
+                metric_row_count.add(batch.as_ref().map(|b| 
b.num_rows()).unwrap_or(0));
+            });
+
+            let schema = self.schema();
+            Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+        } else {
+            let stream = futures::stream::once(
+                execute_query_push(
+                    self.scheduler_url.clone(),
+                    query,
+                    self.config.default_grpc_client_max_message_size(),
+                    GrpcClientConfig::from(&self.config),
+                    Arc::new(self.metrics.clone()),
+                    partition,
+                    session_config,
+                )
+                .map_err(|e| ArrowError::ExternalError(Box::new(e))),
+            )
+            .try_flatten()
+            .inspect(move |batch| {
+                metric_total_bytes.add(
+                    batch
+                        .as_ref()
+                        .map(|b| b.get_array_memory_size())
+                        .unwrap_or(0),
+                );
+
+                metric_row_count.add(batch.as_ref().map(|b| 
b.num_rows()).unwrap_or(0));
+            });
+
+            let schema = self.schema();
+            Ok(Box::pin(RecordBatchStreamAdapter::new(schema, stream)))
+        }
     }
 
     fn statistics(&self) -> Result<Statistics> {
@@ -288,8 +317,11 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for 
DistributedQueryExec<T> {
     }
 }
 
+/// The client will periodically invoke the scheduler to check
+/// the job status. There is a preconfigured wait period between
+/// pulls, which increases query latency.
 #[allow(clippy::too_many_arguments)]
-async fn execute_query(
+async fn execute_query_pull(
     scheduler_url: String,
     session_id: String,
     query: ExecuteQueryParams,
@@ -453,6 +485,160 @@ async fn execute_query(
         };
     }
 }
+/// After the job is scheduled, the client waits
+/// for job status updates, which are streamed back
+/// from the server to the client.
+#[allow(clippy::too_many_arguments)]
+async fn execute_query_push(
+    scheduler_url: String,
+    query: ExecuteQueryParams,
+    max_message_size: usize,
+    grpc_config: GrpcClientConfig,
+    metrics: Arc<ExecutionPlanMetricsSet>,
+    partition: usize,
+    session_config: SessionConfig,
+) -> Result<impl Stream<Item = Result<RecordBatch>> + Send> {
+    let grpc_interceptor = session_config.ballista_grpc_interceptor();
+    let customize_endpoint =
+        session_config.ballista_override_create_grpc_client_endpoint();
+    let use_tls = session_config.ballista_use_tls();
+
+    // Capture query submission time for total_query_time_ms
+    let query_start_time = std::time::Instant::now();
+
+    info!("Connecting to Ballista scheduler at {scheduler_url}");
+    // TODO reuse the scheduler to avoid connecting to the Ballista scheduler 
again and again
+    let mut endpoint =
+        create_grpc_client_endpoint(scheduler_url.clone(), Some(&grpc_config))
+            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+    if let Some(ref customize) = customize_endpoint {
+        endpoint = customize
+            .configure_endpoint(endpoint)
+            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+    }
+
+    let connection = endpoint
+        .connect()
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+    let mut scheduler = SchedulerGrpcClient::with_interceptor(
+        connection,
+        grpc_interceptor.as_ref().clone(),
+    )
+    .max_encoding_message_size(max_message_size)
+    .max_decoding_message_size(max_message_size);
+
+    let mut query_status_stream = scheduler
+        .execute_query_push(query)
+        .await
+        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
+        .into_inner();
+
+    let mut prev_status: Option<job_status::Status> = None;
+
+    loop {
+        let item = query_status_stream
+            .next()
+            .await
+            .ok_or(DataFusionError::Execution(
+                "Stream closed without job completing".to_string(),
+            ))?
+            .map_err(|e| DataFusionError::Execution(e.to_string()))?;
+
+        let GetJobStatusResult {
+            status,
+            flight_proxy,
+        } = item;
+        let job_id = status
+            .as_ref()
+            .map(|s| s.job_id.to_owned())
+            .unwrap_or("unknown_job_id".to_string()); // should not happen
+        let status = status.and_then(|s| s.status);
+        let has_status_change = prev_status != status;
+        match status {
+            None => {
+                if has_status_change {
+                    info!("Job {job_id} is in initialization ...");
+                }
+                prev_status = status;
+            }
+            Some(job_status::Status::Queued(_)) => {
+                if has_status_change {
+                    info!("Job {job_id} is queued...");
+                }
+                prev_status = status;
+            }
+            Some(job_status::Status::Running(_)) => {
+                if has_status_change {
+                    info!("Job {job_id} is running...");
+                }
+                prev_status = status;
+            }
+            Some(job_status::Status::Failed(err)) => {
+                let msg = format!("Job {} failed: {}", job_id, err.error);
+                error!("{msg}");
+                break Err(DataFusionError::Execution(msg));
+            }
+            Some(job_status::Status::Successful(SuccessfulJob {
+                queued_at,
+                started_at,
+                ended_at,
+                partition_location,
+                ..
+            })) => {
+                // Calculate job execution time (server-side execution)
+                let job_execution_ms = ended_at.saturating_sub(started_at);
+                let duration = Duration::from_millis(job_execution_ms);
+
+                info!("Job {job_id} finished executing in {duration:?} ");
+
+                // Calculate scheduling time (server-side queue time)
+                // This includes network latency and actual queue time
+                let scheduling_ms = started_at.saturating_sub(queued_at);
+
+                // Calculate total query time (end-to-end from client 
perspective)
+                let total_elapsed = query_start_time.elapsed();
+                let total_ms = total_elapsed.as_millis();
+
+                // Set timing metrics
+                let metric_job_execution = MetricBuilder::new(&metrics)
+                    .gauge("job_execution_time_ms", partition);
+                metric_job_execution.set(job_execution_ms as usize);
+
+                let metric_scheduling =
+                    MetricBuilder::new(&metrics).gauge("job_scheduling_in_ms", 
partition);
+                metric_scheduling.set(scheduling_ms as usize);
+
+                let metric_total_time =
+                    MetricBuilder::new(&metrics).gauge("total_query_time_ms", 
partition);
+                metric_total_time.set(total_ms as usize);
+
+                // Note: data_transfer_time_ms is not set here because 
partition fetching
+                // happens lazily when the stream is consumed, not during 
execute_query.
+                // This could be added in a future enhancement by wrapping the 
stream.
+
+                let streams = partition_location.into_iter().map(move 
|partition| {
+                    let f = fetch_partition(
+                        partition,
+                        max_message_size,
+                        true,
+                        scheduler_url.clone(),
+                        flight_proxy.clone(),
+                        customize_endpoint.clone(),
+                        use_tls,
+                    )
+                    .map_err(|e| ArrowError::ExternalError(Box::new(e)));
+
+                    futures::stream::once(f).try_flatten()
+                });
+
+                break Ok(futures::stream::iter(streams).flatten());
+            }
+        };
+    }
+}
 
 fn get_client_host_port(
     executor_metadata: &ExecutorMetadata,
diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs
index 2055e723f..c9c4ef1d5 100644
--- a/ballista/core/src/lib.rs
+++ b/ballista/core/src/lib.rs
@@ -21,6 +21,8 @@
 use std::sync::Arc;
 
 use datafusion::{execution::runtime_env::RuntimeEnv, prelude::SessionConfig};
+
+use crate::serde::protobuf::JobStatus;
 /// The current version of Ballista, derived from the Cargo package version.
 pub const BALLISTA_VERSION: &str = env!("CARGO_PKG_VERSION");
 
@@ -76,3 +78,6 @@ pub type RuntimeProducer = Arc<
 /// It is intended to be used with executor configuration
 ///
 pub type ConfigProducer = Arc<dyn Fn() -> SessionConfig + Send + Sync>;
+
+/// Job Notification Subscriber
+pub type JobStatusSubscriber = tokio::sync::mpsc::Sender<JobStatus>;
diff --git a/ballista/core/src/serde/generated/ballista.rs 
b/ballista/core/src/serde/generated/ballista.rs
index adeeaa8a5..1cfe102e5 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -1386,6 +1386,35 @@ pub mod scheduler_grpc_client {
                 );
             self.inner.unary(req, path, codec).await
         }
+        pub async fn execute_query_push(
+            &mut self,
+            request: impl tonic::IntoRequest<super::ExecuteQueryParams>,
+        ) -> std::result::Result<
+            
tonic::Response<tonic::codec::Streaming<super::GetJobStatusResult>>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::unknown(
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic_prost::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/ballista.protobuf.SchedulerGrpc/ExecuteQueryPush",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new(
+                        "ballista.protobuf.SchedulerGrpc",
+                        "ExecuteQueryPush",
+                    ),
+                );
+            self.inner.server_streaming(req, path, codec).await
+        }
         pub async fn execute_query(
             &mut self,
             request: impl tonic::IntoRequest<super::ExecuteQueryParams>,
@@ -1569,6 +1598,19 @@ pub mod scheduler_grpc_server {
             tonic::Response<super::RemoveSessionResult>,
             tonic::Status,
         >;
+        /// Server streaming response type for the ExecuteQueryPush method.
+        type ExecuteQueryPushStream: tonic::codegen::tokio_stream::Stream<
+                Item = std::result::Result<super::GetJobStatusResult, 
tonic::Status>,
+            >
+            + std::marker::Send
+            + 'static;
+        async fn execute_query_push(
+            &self,
+            request: tonic::Request<super::ExecuteQueryParams>,
+        ) -> std::result::Result<
+            tonic::Response<Self::ExecuteQueryPushStream>,
+            tonic::Status,
+        >;
         async fn execute_query(
             &self,
             request: tonic::Request<super::ExecuteQueryParams>,
@@ -1956,6 +1998,53 @@ pub mod scheduler_grpc_server {
                     };
                     Box::pin(fut)
                 }
+                "/ballista.protobuf.SchedulerGrpc/ExecuteQueryPush" => {
+                    #[allow(non_camel_case_types)]
+                    struct ExecuteQueryPushSvc<T: SchedulerGrpc>(pub Arc<T>);
+                    impl<
+                        T: SchedulerGrpc,
+                    > 
tonic::server::ServerStreamingService<super::ExecuteQueryParams>
+                    for ExecuteQueryPushSvc<T> {
+                        type Response = super::GetJobStatusResult;
+                        type ResponseStream = T::ExecuteQueryPushStream;
+                        type Future = BoxFuture<
+                            tonic::Response<Self::ResponseStream>,
+                            tonic::Status,
+                        >;
+                        fn call(
+                            &mut self,
+                            request: tonic::Request<super::ExecuteQueryParams>,
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                <T as 
SchedulerGrpc>::execute_query_push(&inner, request)
+                                    .await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = 
self.accept_compression_encodings;
+                    let send_compression_encodings = 
self.send_compression_encodings;
+                    let max_decoding_message_size = 
self.max_decoding_message_size;
+                    let max_encoding_message_size = 
self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let method = ExecuteQueryPushSvc(inner);
+                        let codec = tonic_prost::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.server_streaming(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
                 "/ballista.protobuf.SchedulerGrpc/ExecuteQuery" => {
                     #[allow(non_camel_case_types)]
                     struct ExecuteQuerySvc<T: SchedulerGrpc>(pub Arc<T>);
diff --git a/ballista/scheduler/src/cluster/memory.rs 
b/ballista/scheduler/src/cluster/memory.rs
index 87ef6709f..f3f70e99c 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -23,22 +23,23 @@ use crate::cluster::{
 };
 use crate::state::execution_graph::ExecutionGraphBox;
 use async_trait::async_trait;
-use ballista_core::ConfigProducer;
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::serde::protobuf::{
     AvailableTaskSlots, ExecutorHeartbeat, ExecutorStatus, FailedJob, 
QueuedJob,
     executor_status,
 };
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
+use ballista_core::{ConfigProducer, JobStatusSubscriber};
 use dashmap::DashMap;
 use datafusion::prelude::{SessionConfig, SessionContext};
+use tokio::sync::mpsc::error::TrySendError;
 
 use crate::cluster::event::ClusterEventSender;
 use crate::scheduler_server::{SessionBuilder, timestamp_millis, 
timestamp_secs};
 use crate::state::session_manager::create_datafusion_context;
 use crate::state::task_manager::JobInfoCache;
 use ballista_core::serde::protobuf::job_status::Status;
-use log::{debug, error, info, warn};
+use log::{error, info, warn};
 use std::collections::{HashMap, HashSet};
 use std::ops::DerefMut;
 
@@ -351,7 +352,7 @@ pub struct InMemoryJobState {
     /// In-memory store of queued jobs. Map from Job ID -> (Job Name, 
queued_at timestamp)
     queued_jobs: DashMap<String, (String, u64)>,
     /// In-memory store of running job statuses. Map from Job ID -> JobStatus
-    running_jobs: DashMap<String, JobStatus>,
+    running_jobs: DashMap<String, ExtendedJobStatus>,
     /// `SessionBuilder` for building DataFusion `SessionContext` from 
`BallistaConfig`
     session_builder: SessionBuilder,
     /// Sender of job events
@@ -380,12 +381,42 @@ impl InMemoryJobState {
     }
 }
 
+#[derive(Clone)]
+struct ExtendedJobStatus {
+    status: JobStatus,
+    subscriber: Option<JobStatusSubscriber>,
+}
+
+impl ExtendedJobStatus {
+    fn update_subscribers(&self, status: JobStatus) {
+        let job_id = status.job_id.clone();
+        if let Some(subscriber) = &self.subscriber
+            && matches!(subscriber.try_send(status), 
Err(TrySendError::Full(_)))
+        {
+            // To be considered: we may need a separate task to retry delivering 
this notification.
+            // At the moment that does not look necessary, as the buffer should 
be big enough for all cases.
+            error!(
+                "jobs notification subscriber for job {} is blocked, can't 
deliver status update, job notification will be missed",
+                job_id
+            )
+        }
+    }
+}
+
 #[async_trait]
 impl JobState for InMemoryJobState {
-    async fn submit_job(&self, job_id: String, graph: &ExecutionGraphBox) -> 
Result<()> {
+    async fn submit_job(
+        &self,
+        job_id: String,
+        graph: &ExecutionGraphBox,
+        subscriber: Option<JobStatusSubscriber>,
+    ) -> Result<()> {
         if self.queued_jobs.get(&job_id).is_some() {
-            self.running_jobs
-                .insert(job_id.clone(), graph.status().clone());
+            let status = ExtendedJobStatus {
+                status: graph.status().clone(),
+                subscriber,
+            };
+            self.running_jobs.insert(job_id.clone(), status);
             self.queued_jobs.remove(&job_id);
 
             self.job_event_sender.send(&JobStateEvent::JobAcquired {
@@ -413,7 +444,7 @@ impl JobState for InMemoryJobState {
         }
 
         if let Some(status) = 
self.running_jobs.get(job_id).as_deref().cloned() {
-            return Ok(Some(status));
+            return Ok(Some(status.status));
         }
 
         if let Some((status, _)) = self.completed_jobs.get(job_id).as_deref() {
@@ -442,20 +473,29 @@ impl JobState for InMemoryJobState {
 
     async fn save_job(&self, job_id: &str, graph: &ExecutionGraphBox) -> 
Result<()> {
         let status = graph.status().clone();
-
-        debug!("saving state for job {job_id} with status {:?}", status);
-
         // If job is either successful or failed, save to completed jobs
         if matches!(
             status.status,
             Some(Status::Successful(_)) | Some(Status::Failed(_))
         ) {
+            if let Some((_, job_info)) = self.running_jobs.remove(job_id) {
+                job_info.update_subscribers(status.clone());
+            }
+
             self.completed_jobs
                 .insert(job_id.to_string(), (status.clone(), 
Some(graph.cloned())));
-            self.running_jobs.remove(job_id);
         } else {
             // otherwise update running job
-            self.running_jobs.insert(job_id.to_string(), status.clone());
+            if let Some(mut job_info) = self.running_jobs.get_mut(job_id) {
+                job_info.status = status.clone();
+                // clone the subscriber so we do not await while holding the lock
+                job_info.update_subscribers(status.clone());
+            } else {
+                Err(BallistaError::Internal(format!(
+                    "scheduler state can't find job: {}",
+                    job_id
+                )))?
+            };
         }
 
         // job change event emitted
diff --git a/ballista/scheduler/src/cluster/mod.rs 
b/ballista/scheduler/src/cluster/mod.rs
index 5d01091cc..5647ff89a 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -37,7 +37,7 @@ use ballista_core::serde::protobuf::{
 };
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata, 
PartitionId};
 use ballista_core::utils::{default_config_producer, default_session_builder};
-use ballista_core::{ConfigProducer, consistent_hash};
+use ballista_core::{ConfigProducer, JobStatusSubscriber, consistent_hash};
 
 use crate::cluster::memory::{InMemoryClusterState, InMemoryJobState};
 
@@ -301,7 +301,12 @@ pub trait JobState: Send + Sync {
     /// Submits a new job to the job state.
     ///
     /// The submitter is assumed to own the job.
-    async fn submit_job(&self, job_id: String, graph: &ExecutionGraphBox) -> 
Result<()>;
+    async fn submit_job(
+        &self,
+        job_id: String,
+        graph: &ExecutionGraphBox,
+        subscriber: Option<JobStatusSubscriber>,
+    ) -> Result<()>;
 
     /// Returns the set of all active job IDs.
     async fn get_jobs(&self) -> Result<HashSet<String>>;
diff --git a/ballista/scheduler/src/cluster/test_util/mod.rs 
b/ballista/scheduler/src/cluster/test_util/mod.rs
index d9e960004..71693930e 100644
--- a/ballista/scheduler/src/cluster/test_util/mod.rs
+++ b/ballista/scheduler/src/cluster/test_util/mod.rs
@@ -89,7 +89,7 @@ impl<S: JobState> JobStateTest<S> {
     /// Submits a job with the given execution graph.
     pub async fn submit_job(self, graph: &ExecutionGraphBox) -> Result<Self> {
         self.state
-            .submit_job(graph.job_id().to_string(), graph)
+            .submit_job(graph.job_id().to_string(), graph, None)
             .await?;
         Ok(self)
     }
diff --git a/ballista/scheduler/src/scheduler_server/event.rs 
b/ballista/scheduler/src/scheduler_server/event.rs
index 77ef7a65e..c6d11fb1b 100644
--- a/ballista/scheduler/src/scheduler_server/event.rs
+++ b/ballista/scheduler/src/scheduler_server/event.rs
@@ -20,7 +20,7 @@ use std::fmt::{Debug, Formatter};
 use datafusion::logical_expr::LogicalPlan;
 
 use crate::state::execution_graph::RunningTaskInfo;
-use ballista_core::serde::protobuf::TaskStatus;
+use ballista_core::{JobStatusSubscriber, serde::protobuf::TaskStatus};
 use datafusion::prelude::SessionContext;
 use std::sync::Arc;
 
@@ -39,6 +39,8 @@ pub enum QueryStageSchedulerEvent {
         plan: Box<LogicalPlan>,
         /// Timestamp when the job was queued.
         queued_at: u64,
+        /// job status subscriber
+        subscriber: Option<JobStatusSubscriber>,
     },
     /// A job has been submitted for execution.
     JobSubmitted {
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs 
b/ballista/scheduler/src/scheduler_server/grpc.rs
index 072b8a4d3..4e1e9e5b1 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -17,6 +17,7 @@
 
 use axum::extract::ConnectInfo;
 use ballista_core::config::BALLISTA_JOB_NAME;
+use ballista_core::error::{BallistaError, Result as BResult};
 use ballista_core::extension::SessionConfigHelperExt;
 use ballista_core::serde::protobuf::execute_query_params::Query;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
@@ -26,16 +27,20 @@ use ballista_core::serde::protobuf::{
     ExecuteQueryFailureResult, ExecuteQueryParams, ExecuteQueryResult,
     ExecuteQuerySuccessResult, ExecutorHeartbeat, ExecutorStoppedParams,
     ExecutorStoppedResult, GetJobStatusParams, GetJobStatusResult, 
HeartBeatParams,
-    HeartBeatResult, PollWorkParams, PollWorkResult, RegisterExecutorParams,
-    RegisterExecutorResult, RemoveSessionParams, RemoveSessionResult,
-    UpdateTaskStatusParams, UpdateTaskStatusResult, 
execute_query_failure_result,
-    execute_query_result,
+    HeartBeatResult, JobStatus, KeyValuePair, PollWorkParams, PollWorkResult,
+    RegisterExecutorParams, RegisterExecutorResult, RemoveSessionParams,
+    RemoveSessionResult, UpdateTaskStatusParams, UpdateTaskStatusResult,
+    execute_query_failure_result, execute_query_result,
 };
 use ballista_core::serde::scheduler::ExecutorMetadata;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::AsExecutionPlan;
+use futures::{Stream, StreamExt};
 use log::{debug, error, info, trace, warn};
 use std::net::SocketAddr;
+use std::pin::Pin;
+
+use tokio_stream::wrappers::ReceiverStream;
 
 #[cfg(feature = "substrait")]
 use {
@@ -43,12 +48,14 @@ use {
     datafusion_substrait::serializer::deserialize_bytes,
 };
 
-use std::ops::Deref;
-
 use crate::cluster::{bind_task_bias, bind_task_round_robin};
 use crate::config::TaskDistributionPolicy;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use ballista_core::serde::protobuf::get_job_status_result::FlightProxy;
+use datafusion::logical_expr::LogicalPlan;
+use datafusion::prelude::SessionContext;
+use std::ops::Deref;
+use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tonic::{Request, Response, Status};
 
@@ -58,6 +65,9 @@ use crate::scheduler_server::SchedulerServer;
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
     for SchedulerServer<T, U>
 {
+    type ExecuteQueryPushStream =
+        Pin<Box<dyn Stream<Item = Result<GetJobStatusResult, Status>> + Send>>;
+
     async fn poll_work(
         &self,
         request: Request<PollWorkParams>,
@@ -333,10 +343,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
         Ok(Response::new(RemoveSessionResult { success: true }))
     }
 
-    async fn execute_query(
+    async fn execute_query_push(
         &self,
-        request: Request<ExecuteQueryParams>,
-    ) -> Result<Response<ExecuteQueryResult>, Status> {
+        request: tonic::Request<ExecuteQueryParams>,
+    ) -> std::result::Result<tonic::Response<Self::ExecuteQueryPushStream>, 
tonic::Status>
+    {
         let query_params = request.into_inner();
         if let ExecuteQueryParams {
             query: Some(query),
@@ -351,44 +362,99 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
                 .and_then(|s| s.value.clone())
                 .unwrap_or_default();
 
-            let job_id = self.state.task_manager.generate_job_id();
+            info!(
+                "execution query (PUSH) job received - session_id: 
{session_id}, operation_id: {operation_id}, job_name: {job_name}"
+            );
+
+            let (session_id, session_ctx) = self
+                .create_context(&settings, session_id)
+                .await
+                .map_err(|e| {
+                    Status::internal(format!("Failed to create SessionContext: 
{e:?}"))
+                })?;
+
+            let plan = self.parse_plan(query, &session_ctx).await.map_err(|e| {
+                let msg = format!("Could not parse plan: {e}");
+                error!("{}", msg);
+
+                Status::invalid_argument(msg)
+            })?;
+
+            debug!(
+                "Decoded logical plan for execution:\n{}",
+                plan.display_indent()
+            );
+            log::trace!("setting job name: {job_name}");
+
+            let flight_proxy = self.flight_proxy_config();
+
+            let (subscriber, rx) = tokio::sync::mpsc::channel::<JobStatus>(16);
+            let stream = ReceiverStream::new(rx).map(move |status| {
+                Ok::<_, tonic::Status>(GetJobStatusResult {
+                    status: Some(status),
+                    flight_proxy: flight_proxy.clone(),
+                })
+            });
+
+            let job_id = self
+                .submit_job(&job_name, session_ctx, &plan, Some(subscriber))
+                .await
+                .map_err(|e| {
+                    let msg =
+                        format!("Failed to send JobQueued event for 
{job_name}: {e:?}");
+
+                    error!("{msg}");
+
+                    Status::internal(msg)
+                })?;
 
             info!(
-                "execution query - session_id: {session_id}, operation_id: 
{operation_id}, job_name: {job_name}, job_id: {job_id}"
+                "execution query (PUSH) job submitted - session_id: 
{session_id}, operation_id: {operation_id}, job_name: {job_name}, job_id: 
{job_id}"
             );
 
-            let (session_id, session_ctx) = {
-                let session_config = 
self.state.session_manager.produce_config();
-                let session_config = 
session_config.update_from_key_value_pair(&settings);
+            Ok(Response::new(Box::pin(stream)))
+        } else {
+            Err(Status::internal(
+                "Error processing request, invalid message",
+            ))
+        }
+    }
 
-                let ctx = self
-                    .state
-                    .session_manager
-                    .create_or_update_session(&session_id, &session_config)
-                    .await
-                    .map_err(|e| {
-                        Status::internal(format!(
-                            "Failed to create SessionContext: {e:?}"
-                        ))
-                    })?;
+    async fn execute_query(
+        &self,
+        request: Request<ExecuteQueryParams>,
+    ) -> Result<Response<ExecuteQueryResult>, Status> {
+        let query_params = request.into_inner();
+        if let ExecuteQueryParams {
+            query: Some(query),
+            session_id,
+            operation_id,
+            settings,
+        } = query_params
+        {
+            let job_name = settings
+                .iter()
+                .find(|s| s.key == BALLISTA_JOB_NAME)
+                .and_then(|s| s.value.clone())
+                .unwrap_or_default();
 
-                (session_id, ctx)
-            };
+            info!(
+                "execution query job received - session_id: {session_id}, 
operation_id: {operation_id}, job_name: {job_name}"
+            );
+
+            let (session_id, session_ctx) = self
+                .create_context(&settings, session_id)
+                .await
+                .map_err(|e| {
+                    Status::internal(format!("Failed to create SessionContext: 
{e:?}"))
+                })?;
 
-            let plan = match query {
-                Query::LogicalPlan(message) => {
-                    match T::try_decode(message.as_slice()).and_then(|m| {
-                        m.try_into_logical_plan(
-                            session_ctx.task_ctx().deref(),
-                            self.state.codec.logical_extension_codec(),
-                        )
-                    }) {
-                        Ok(plan) => plan,
-                        Err(e) => {
-                            let msg =
-                                format!("Could not parse logical plan 
protobuf: {e}");
-                            error!("{msg}");
-                            return Ok(Response::new(ExecuteQueryResult {
+            let plan = match self.parse_plan(query, &session_ctx).await {
+                Ok(plan) => plan,
+                Err(e) => {
+                    let msg = format!("Could not parse plan: {e}");
+                    error!("{msg}");
+                    return Ok(Response::new(ExecuteQueryResult {
                                 operation_id,
                                 result: 
Some(execute_query_result::Result::Failure(
                                     ExecuteQueryFailureResult {
@@ -396,38 +462,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
                                     },
                                 )),
                             }));
-                        }
-                    }
-                }
-                #[cfg(not(feature = "substrait"))]
-                Query::SubstraitPlan(_) => {
-                    let msg = "Received query type \"Substrait\", enable 
\"substrait\" feature to support Substrait plans.".to_string();
-                    error!("{msg}");
-                    return Ok(Response::new(ExecuteQueryResult {
-                        operation_id,
-                        result: Some(execute_query_result::Result::Failure(
-                            ExecuteQueryFailureResult {
-                                failure: 
Some(execute_query_failure_result::Failure::PlanParsingFailure(msg)),
-                            }
-                        ))
-                    }));
-                }
-                #[cfg(feature = "substrait")]
-                Query::SubstraitPlan(bytes) => {
-                    let plan = deserialize_bytes(bytes).await.map_err(|e| {
-                        let msg = format!("Could not parse substrait plan: 
{e}");
-                        error!("{}", msg);
-                        Status::internal(msg)
-                    })?;
-
-                    let ctx = session_ctx.as_ref().clone();
-                    from_substrait_plan(&ctx.state(), &plan)
-                        .await
-                        .map_err(|e| {
-                            let msg = format!("Could not parse substrait plan: 
{e}");
-                            error!("{}", msg);
-                            Status::internal(msg)
-                        })?
                 }
             };
 
@@ -438,7 +472,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
 
             log::trace!("setting job name: {job_name}");
             let job_id = self
-                .submit_job(&job_name, session_ctx, &plan)
+                .submit_job(&job_name, session_ctx, &plan, None)
                 .await
                 .map_err(|e| {
                     let msg =
@@ -448,6 +482,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
                     Status::internal(msg)
                 })?;
 
+            info!(
+                "execution query, job submitted - session_id: {session_id}, 
operation_id: {operation_id}, job_name: {job_name}"
+            );
+
             Ok(Response::new(ExecuteQueryResult {
                 operation_id,
                 result: Some(execute_query_result::Result::Success(
@@ -466,15 +504,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
         let job_id = request.into_inner().job_id;
         trace!("Received get_job_status request for job {}", job_id);
 
-        let flight_proxy =
-            self.state
-                .config
-                .advertise_flight_sql_endpoint
-                .clone()
-                .map(|s| match s {
-                    s if s.is_empty() => FlightProxy::Local(true),
-                    s => FlightProxy::External(s),
-                });
+        let flight_proxy = self.flight_proxy_config();
 
         match self.state.task_manager.get_job_status(&job_id).await {
             Ok(status) => Ok(Response::new(GetJobStatusResult {
@@ -569,6 +599,67 @@ fn extract_connect_info<T>(request: &Request<T>) -> 
Option<ConnectInfo<SocketAdd
         .cloned()
 }
 
+impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> 
SchedulerServer<T, U> {
+    async fn create_context(
+        &self,
+        settings: &[KeyValuePair],
+        session_id: String,
+    ) -> BResult<(String, Arc<SessionContext>)> {
+        let session_config = self.state.session_manager.produce_config();
+        let session_config = 
session_config.update_from_key_value_pair(settings);
+
+        let ctx = self
+            .state
+            .session_manager
+            .create_or_update_session(&session_id, &session_config)
+            .await?;
+
+        Ok((session_id, ctx))
+    }
+
+    async fn parse_plan(
+        &self,
+        query: Query,
+        session_ctx: &SessionContext,
+    ) -> BResult<LogicalPlan> {
+        match query {
+            Query::LogicalPlan(message) => T::try_decode(message.as_slice())
+                .and_then(|m| {
+                    m.try_into_logical_plan(
+                        session_ctx.task_ctx().deref(),
+                        self.state.codec.logical_extension_codec(),
+                    )
+                })
+                .map_err(|e| e.into()),
+
+            #[cfg(not(feature = "substrait"))]
+            Query::SubstraitPlan(_) => {
+                Err(BallistaError::NotImplemented("Received query type 
\"Substrait\", enable \"substrait\" feature to support Substrait 
plans.".to_string()))
+            }
+            #[cfg(feature = "substrait")]
+            Query::SubstraitPlan(bytes) => {
+                let plan = deserialize_bytes(bytes).await.map_err(|e| 
BallistaError::DataFusionError(e.into()))?;
+
+                let ctx = session_ctx.clone();
+                from_substrait_plan(&ctx.state(), &plan)
+                    .await
+                    .map_err(|e| e.into())
+            }
+        }
+    }
+
+    fn flight_proxy_config(&self) -> Option<FlightProxy> {
+        self.state
+            .config
+            .advertise_flight_sql_endpoint
+            .clone()
+            .map(|s| match s {
+                s if s.is_empty() => FlightProxy::Local(true),
+                s => FlightProxy::External(s),
+            })
+    }
+}
+
 #[cfg(test)]
 mod test {
     use std::sync::Arc;
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs 
b/ballista/scheduler/src/scheduler_server/mod.rs
index 94b34a1dc..8a0fbec2b 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -18,6 +18,7 @@
 use std::sync::Arc;
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
+use ballista_core::JobStatusSubscriber;
 use ballista_core::error::Result;
 use ballista_core::event_loop::{EventLoop, EventSender};
 use ballista_core::serde::BallistaCodec;
@@ -222,10 +223,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
         job_name: &str,
         ctx: Arc<SessionContext>,
         plan: &LogicalPlan,
+        subscriber: Option<JobStatusSubscriber>,
     ) -> Result<String> {
         log::debug!("Received submit request for job {job_name}");
         let job_id = self.state.task_manager.generate_job_id();
-
         self.query_stage_event_loop
             .get_sender()?
             .post_event(QueryStageSchedulerEvent::JobQueued {
@@ -234,6 +235,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
                 session_ctx: ctx,
                 plan: Box::new(plan.clone()),
                 queued_at: timestamp_millis(),
+                subscriber,
             })
             .await?;
 
@@ -410,6 +412,7 @@ mod test {
     use std::sync::Arc;
 
     use ballista_core::extension::SessionConfigExt;
+    use ballista_core::serde::protobuf::job_status::Status;
     use datafusion::arrow::datatypes::{DataType, Field, Schema};
     use datafusion::functions_aggregate::sum::sum;
     use datafusion::logical_expr::{LogicalPlan, col};
@@ -480,7 +483,7 @@ mod test {
         // Submit job
         scheduler
             .state
-            .submit_job(job_id, "", ctx, &plan, 0)
+            .submit_job(job_id, "", ctx, &plan, 0, None)
             .await
             .expect("submitting plan");
 
@@ -590,6 +593,64 @@ mod test {
         Ok(())
     }
 
+    // Checks that the job status subscriber receives the same events.
+    #[tokio::test]
+    async fn test_push_scheduling_with_subscriber() -> Result<()> {
+        let plan = test_plan();
+        // this test will fail when AQE scheduling is used.
+        // as AQE will fold plan due to empty scan
+        let metrics_collector = Arc::new(TestMetricsCollector::default());
+
+        let mut test = SchedulerTest::new(
+            SchedulerConfig::default()
+                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
+            metrics_collector.clone(),
+            4,
+            1,
+            None,
+        )
+        .await?;
+        let (tx, mut rx) = tokio::sync::mpsc::channel(16);
+
+        let (status, job_id) = test
+            .run_with_subscriber("", &plan, Some(tx))
+            .await
+            .expect("running plan");
+
+        match status.status {
+            Some(job_status::Status::Successful(SuccessfulJob {
+                partition_location,
+                ..
+            })) => {
+                assert_eq!(partition_location.len(), 4);
+            }
+            other => {
+                panic!("Expected success status but found {other:?}");
+            }
+        }
+
+        assert_submitted_event(&job_id, &metrics_collector);
+        assert_completed_event(&job_id, &metrics_collector);
+
+        let mut buffer = vec![];
+        rx.recv_many(&mut buffer, 16).await;
+        assert!(!buffer.is_empty());
+
+        let successful_job = buffer
+            .iter()
+            .find(|s| matches!(s.status, Some(Status::Successful(_))));
+
+        assert!(successful_job.is_some());
+
+        let failed_job = buffer
+            .iter()
+            .find(|s| matches!(s.status, Some(Status::Failed(_))));
+
+        assert!(failed_job.is_none());
+
+        Ok(())
+    }
+
     // Simulate a task failure and ensure the job status is updated correctly
     #[tokio::test]
     async fn test_job_failure() -> Result<()> {
@@ -665,6 +726,102 @@ mod test {
         Ok(())
     }
 
+    // Simulate a task failure and ensure the job status is updated correctly.
+    // It also checks that the job status subscriber receives the same events.
+    #[tokio::test]
+    async fn test_job_failure_subscriber() -> Result<()> {
+        let plan = test_plan();
+
+        let runner = Arc::new(TaskRunnerFn::new(
+            |_executor_id: String, task: MultiTaskDefinition| {
+                let mut statuses = vec![];
+
+                for TaskId {
+                    task_id,
+                    partition_id,
+                    ..
+                } in task.task_ids
+                {
+                    let timestamp = timestamp_millis();
+                    statuses.push(TaskStatus {
+                        task_id,
+                        job_id: task.job_id.clone(),
+                        stage_id: task.stage_id,
+                        stage_attempt_num: task.stage_attempt_num,
+                        partition_id,
+                        launch_time: timestamp,
+                        start_exec_time: timestamp,
+                        end_exec_time: timestamp,
+                        metrics: vec![],
+                        status: Some(task_status::Status::Failed(FailedTask {
+                            error: "ERROR".to_string(),
+                            retryable: false,
+                            count_to_failures: false,
+                            failed_reason: Some(
+                                failed_task::FailedReason::ExecutionError(
+                                    ExecutionError {},
+                                ),
+                            ),
+                        })),
+                    });
+                }
+
+                statuses
+            },
+        ));
+
+        let metrics_collector = Arc::new(TestMetricsCollector::default());
+
+        let mut test = SchedulerTest::new(
+            SchedulerConfig::default()
+                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
+            metrics_collector.clone(),
+            4,
+            1,
+            Some(runner),
+        )
+        .await?;
+        let (tx, mut rx) = tokio::sync::mpsc::channel(16);
+        let (status, job_id) = test
+            .run_with_subscriber("", &plan, Some(tx))
+            .await
+            .expect("running plan");
+
+        assert!(
+            matches!(
+                status,
+                JobStatus {
+                    status: Some(job_status::Status::Failed(_)),
+                    ..
+                }
+            ),
+            "{}",
+            "Expected job status to be failed but it was {status:?}"
+        );
+
+        assert_submitted_event(&job_id, &metrics_collector);
+        assert_failed_event(&job_id, &metrics_collector);
+
+        let mut buffer = vec![];
+        rx.recv_many(&mut buffer, 16).await;
+
+        assert!(!buffer.is_empty());
+
+        let failed_job = buffer
+            .iter()
+            .find(|s| matches!(s.status, Some(Status::Failed(_))));
+
+        assert!(failed_job.is_some());
+
+        let successful_job = buffer
+            .iter()
+            .find(|s| matches!(s.status, Some(Status::Successful(_))));
+
+        assert!(successful_job.is_none());
+
+        Ok(())
+    }
+
     // If the physical planning fails, the job should be marked as failed.
     // Here we simulate a planning failure using ExplodingTableProvider to 
test this.
     #[tokio::test]
@@ -710,6 +867,68 @@ mod test {
         Ok(())
     }
 
+    // If the physical planning fails, the job should be marked as failed.
+    // Here we simulate a planning failure using ExplodingTableProvider to 
test this.
+    // It also checks that the job status subscriber receives the same events.
+    #[tokio::test]
+    async fn test_planning_failure_with_subscriber() -> Result<()> {
+        let metrics_collector = Arc::new(TestMetricsCollector::default());
+        let mut test = SchedulerTest::new(
+            SchedulerConfig::default()
+                .with_scheduler_policy(TaskSchedulingPolicy::PushStaged),
+            metrics_collector.clone(),
+            4,
+            1,
+            None,
+        )
+        .await?;
+
+        let ctx = test.ctx().await?;
+
+        ctx.register_table("explode", Arc::new(ExplodingTableProvider))?;
+
+        let plan = ctx
+            .sql("SELECT * FROM explode")
+            .await?
+            .into_optimized_plan()?;
+        let (tx, mut rx) = tokio::sync::mpsc::channel(16);
+        // This should fail when we try and create the physical plan
+        let (status, job_id) = test.run_with_subscriber("", &plan, 
Some(tx)).await?;
+
+        assert!(
+            matches!(
+                status,
+                JobStatus {
+                    status: Some(job_status::Status::Failed(_)),
+                    ..
+                }
+            ),
+            "{}",
+            "Expected job status to be failed but it was {status:?}"
+        );
+
+        assert_no_submitted_event(&job_id, &metrics_collector);
+        assert_failed_event(&job_id, &metrics_collector);
+
+        let mut buffer = vec![];
+        rx.recv_many(&mut buffer, 16).await;
+        assert!(!buffer.is_empty());
+
+        let failed_job = buffer
+            .iter()
+            .find(|s| matches!(s.status, Some(Status::Failed(_))));
+
+        assert!(failed_job.is_some());
+
+        let successful_job = buffer
+            .iter()
+            .find(|s| matches!(s.status, Some(Status::Successful(_))));
+
+        assert!(successful_job.is_none());
+
+        Ok(())
+    }
+
     async fn test_scheduler(
         scheduling_policy: TaskSchedulingPolicy,
     ) -> Result<SchedulerServer<LogicalPlanNode, PhysicalPlanNode>> {
diff --git a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs 
b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
index 97e7f8569..85b0924e2 100644
--- a/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
+++ b/ballista/scheduler/src/scheduler_server/query_stage_scheduler.rs
@@ -19,10 +19,12 @@ use std::sync::Arc;
 use std::time::Duration;
 
 use async_trait::async_trait;
+use ballista_core::serde::protobuf::{FailedJob, JobStatus};
 use log::{error, info, trace, warn};
 
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::event_loop::{EventAction, EventSender};
+use tokio::sync::mpsc::error::TrySendError;
 
 use crate::config::SchedulerConfig;
 use crate::metrics::SchedulerMetricsCollector;
@@ -93,6 +95,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
                 session_ctx,
                 plan,
                 queued_at,
+                subscriber,
             } => {
                 info!("Job {job_id} queued with name {job_name:?}");
 
@@ -108,10 +111,42 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan>
                 let state = self.state.clone();
                 tokio::spawn(async move {
                     let event = if let Err(e) = state
-                        .submit_job(&job_id, &job_name, session_ctx, &plan, 
queued_at)
+                        .submit_job(
+                            &job_id,
+                            &job_name,
+                            session_ctx,
+                            &plan,
+                            queued_at,
+                            subscriber.clone(),
+                        )
                         .await
                     {
+                        let error = e.to_string();
                         let fail_message = format!("Error planning job 
{job_id}: {e:?}");
+
+                        // This is a corner case: most job status changes are
+                        // handled by the job state after the job has been submitted to it.
+                        if let Some(subscriber) = subscriber {
+                            let timestamp = timestamp_millis();
+                            let job_status = JobStatus {
+                                job_id: job_id.clone(),
+                                job_name,
+                                status: 
Some(ballista_core::serde::protobuf::job_status::Status::Failed(
+                                    FailedJob { error, queued_at, started_at: 
timestamp, ended_at: timestamp }
+                                ))
+                            };
+
+                            if matches!(
+                                subscriber.try_send(job_status),
+                                Err(TrySendError::Full(_))
+                            ) {
+                                error!(
+                                    "jobs notification subscriber for job {} 
is blocked, can't deliver status update, job notification will be missed",
+                                    job_id
+                                )
+                            }
+                        }
+
                         error!("{}", &fail_message);
                         QueryStageSchedulerEvent::JobPlanningFailed {
                             job_id,
diff --git a/ballista/scheduler/src/state/mod.rs 
b/ballista/scheduler/src/state/mod.rs
index 49e473c58..82fa60c97 100644
--- a/ballista/scheduler/src/state/mod.rs
+++ b/ballista/scheduler/src/state/mod.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use ballista_core::JobStatusSubscriber;
 use datafusion::common::tree_node::{Transformed, TreeNode, TreeNodeRecursion};
 use datafusion::datasource::listing::{ListingTable, ListingTableUrl};
 use datafusion::datasource::source_as_provider;
@@ -379,6 +380,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerState<T,
         session_ctx: Arc<SessionContext>,
         plan: &LogicalPlan,
         queued_at: u64,
+        subscriber: Option<JobStatusSubscriber>,
     ) -> Result<()> {
         let start = Instant::now();
         let session_config = Arc::new(session_ctx.copied_config());
@@ -487,6 +489,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerState<T,
                 plan.data,
                 queued_at,
                 session_config,
+                subscriber,
             )
             .await?;
 
diff --git a/ballista/scheduler/src/state/task_manager.rs 
b/ballista/scheduler/src/state/task_manager.rs
index ae3686901..b55806407 100644
--- a/ballista/scheduler/src/state/task_manager.rs
+++ b/ballista/scheduler/src/state/task_manager.rs
@@ -23,6 +23,7 @@ use crate::state::execution_graph::{
 };
 use crate::state::executor_manager::ExecutorManager;
 
+use ballista_core::JobStatusSubscriber;
 use ballista_core::error::BallistaError;
 use ballista_core::error::Result;
 use ballista_core::extension::{SessionConfigExt, SessionConfigHelperExt};
@@ -156,6 +157,7 @@ impl JobInfoCache {
     /// Creates a new `JobInfoCache` from an execution graph.
     pub fn new(graph: ExecutionGraphBox) -> Self {
         let status = graph.status().status.clone();
+
         Self {
             execution_graph: Arc::new(RwLock::new(graph)),
             status,
@@ -266,6 +268,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
     /// Generate an ExecutionGraph for the job and save it to the persistent 
state.
     /// By default, this job will be curated by the scheduler which receives 
it.
     /// Then we will also save it to the active execution graph
+    #[allow(clippy::too_many_arguments)]
     pub async fn submit_job(
         &self,
         job_id: &str,
@@ -274,6 +277,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         plan: Arc<dyn ExecutionPlan>,
         queued_at: u64,
         session_config: Arc<SessionConfig>,
+        subscriber: Option<JobStatusSubscriber>,
     ) -> Result<()> {
         let mut planner = DefaultDistributedPlanner::new();
 
@@ -307,7 +311,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
 
         info!("Submitting execution graph:\n\n{graph:?}");
 
-        self.state.submit_job(job_id.to_string(), &graph).await?;
+        self.state
+            .submit_job(job_id.to_string(), &graph, subscriber)
+            .await?;
         graph.revive();
         self.active_job_cache
             .insert(job_id.to_owned(), JobInfoCache::new(graph));
diff --git a/ballista/scheduler/src/test_utils.rs 
b/ballista/scheduler/src/test_utils.rs
index 7ac6f83a9..1d4f3633f 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use ballista_core::JobStatusSubscriber;
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::extension::SessionConfigExt;
 use datafusion::catalog::Session;
@@ -507,7 +508,7 @@ impl SchedulerTest {
             .create_or_update_session("session_id", &self.session_config)
             .await?;
 
-        let job_id = self.scheduler.submit_job(job_name, ctx, plan).await?;
+        let job_id = self.scheduler.submit_job(job_name, ctx, plan, 
None).await?;
 
         Ok(job_id)
     }
@@ -627,6 +628,15 @@ impl SchedulerTest {
         &mut self,
         job_name: &str,
         plan: &LogicalPlan,
+    ) -> Result<(JobStatus, String)> {
+        self.run_with_subscriber(job_name, plan, None).await
+    }
+    /// Returns job status and job_id, with provided subscriber
+    pub async fn run_with_subscriber(
+        &mut self,
+        job_name: &str,
+        plan: &LogicalPlan,
+        subscriber: Option<JobStatusSubscriber>,
     ) -> Result<(JobStatus, String)> {
         let ctx = self
             .scheduler
@@ -635,7 +645,10 @@ impl SchedulerTest {
             .create_or_update_session("session_id", &self.session_config)
             .await?;
 
-        let job_id = self.scheduler.submit_job(job_name, ctx, plan).await?;
+        let job_id = self
+            .scheduler
+            .submit_job(job_name, ctx, plan, subscriber)
+            .await?;
 
         let mut receiver = self.status_receiver.take().unwrap();
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to