This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 30f30428 feat: `ClusterState` does not cache session contexts (#1226)
30f30428 is described below

commit 30f304287fa1e6fbc6c112faec583d10594277c4
Author: Marko Milenković <milenkov...@users.noreply.github.com>
AuthorDate: Sun Jun 1 10:36:41 2025 +0100

    feat: `ClusterState` does not cache session contexts (#1226)
    
    * make scheduler session context stateless ...
    
    ... as session context was not cleaned up.
    
    * rename create_session to create_or_update session
    
    * update proto removing update session and adding create or update
    
    * remove unused code
    
    * add test to cover session events
    
    * expose operation_id and make session_id required
    
    * minor generated code update
    
    * fix standalone context creation
    
    * core extension cleanup
---
 Cargo.lock                                         |   1 +
 ballista/client/src/extension.rs                   |  86 +++-----------
 ballista/core/Cargo.toml                           |   1 +
 ballista/core/proto/ballista.proto                 |  16 ++-
 .../core/src/execution_plans/distributed_query.rs  |  11 +-
 ballista/core/src/extension.rs                     |  62 +++++-----
 ballista/core/src/serde/generated/ballista.rs      | 131 ++++++---------------
 ballista/executor/src/standalone.rs                |   8 +-
 ballista/scheduler/src/cluster/memory.rs           |  87 +++++++++-----
 ballista/scheduler/src/cluster/mod.rs              |  30 +----
 ballista/scheduler/src/scheduler_server/grpc.rs    | 119 ++++++-------------
 ballista/scheduler/src/scheduler_server/mod.rs     |   2 +-
 ballista/scheduler/src/standalone.rs               |  11 +-
 ballista/scheduler/src/state/session_manager.rs    |  23 +---
 ballista/scheduler/src/test_utils.rs               |   6 +-
 examples/tests/common/mod.rs                       |   5 +-
 16 files changed, 213 insertions(+), 386 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 80a47390..b0a77d4d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -999,6 +999,7 @@ dependencies = [
  "tonic",
  "tonic-build",
  "url",
+ "uuid",
 ]
 
 [[package]]
diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs
index 89e1c7ce..54b21234 100644
--- a/ballista/client/src/extension.rs
+++ b/ballista/client/src/extension.rs
@@ -15,16 +15,10 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use ballista_core::extension::SessionConfigHelperExt;
 pub use ballista_core::extension::{SessionConfigExt, SessionStateExt};
-use ballista_core::{
-    serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams},
-    utils::create_grpc_client_connection,
-};
+use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
 use datafusion::{
-    error::DataFusionError,
-    execution::SessionState,
-    prelude::{SessionConfig, SessionContext},
+    error::DataFusionError, execution::SessionState, prelude::SessionContext,
 };
 use url::Url;
 
@@ -100,40 +94,34 @@ impl SessionContextExt for SessionContext {
         url: &str,
         state: SessionState,
     ) -> datafusion::error::Result<SessionContext> {
-        let config = state.config();
-
         let scheduler_url = Extension::parse_url(url)?;
         log::info!(
             "Connecting to Ballista scheduler at {}",
             scheduler_url.clone()
         );
-        let remote_session_id =
-            Extension::setup_remote(config, scheduler_url.clone()).await?;
+
+        let session_state = state.upgrade_for_ballista(scheduler_url)?;
+
         log::info!(
             "Server side SessionContext created with session id: {}",
-            remote_session_id
+            session_state.session_id()
         );
-        let session_state =
-            state.upgrade_for_ballista(scheduler_url, remote_session_id)?;
 
         Ok(SessionContext::new_with_state(session_state))
     }
 
     async fn remote(url: &str) -> datafusion::error::Result<SessionContext> {
-        let config = SessionConfig::new_with_ballista();
         let scheduler_url = Extension::parse_url(url)?;
         log::info!(
             "Connecting to Ballista scheduler at: {}",
             scheduler_url.clone()
         );
-        let remote_session_id =
-            Extension::setup_remote(&config, scheduler_url.clone()).await?;
+
+        let session_state = SessionState::new_ballista_state(scheduler_url)?;
         log::info!(
             "Server side SessionContext created with session id: {}",
-            remote_session_id
+            session_state.session_id()
         );
-        let session_state =
-            SessionState::new_ballista_state(scheduler_url, remote_session_id)?;
 
         Ok(SessionContext::new_with_state(session_state))
     }
@@ -142,15 +130,13 @@ impl SessionContextExt for SessionContext {
     async fn standalone_with_state(
         state: SessionState,
     ) -> datafusion::error::Result<SessionContext> {
-        let (remote_session_id, scheduler_url) =
-            Extension::setup_standalone(Some(&state)).await?;
+        let scheduler_url = Extension::setup_standalone(Some(&state)).await?;
 
-        let session_state =
-            state.upgrade_for_ballista(scheduler_url, remote_session_id.clone())?;
+        let session_state = state.upgrade_for_ballista(scheduler_url)?;
 
         log::info!(
             "Server side SessionContext created with session id: {}",
-            remote_session_id
+            session_state.session_id()
         );
 
         Ok(SessionContext::new_with_state(session_state))
@@ -160,15 +146,13 @@ impl SessionContextExt for SessionContext {
     async fn standalone() -> datafusion::error::Result<Self> {
         log::info!("Running in local mode. Scheduler will be run in-proc");
 
-        let (remote_session_id, scheduler_url) =
-            Extension::setup_standalone(None).await?;
+        let scheduler_url = Extension::setup_standalone(None).await?;
 
-        let session_state =
-            SessionState::new_ballista_state(scheduler_url, remote_session_id.clone())?;
+        let session_state = SessionState::new_ballista_state(scheduler_url)?;
 
         log::info!(
             "Server side SessionContext created with session id: {}",
-            remote_session_id
+            session_state.session_id()
         );
 
         Ok(SessionContext::new_with_state(session_state))
@@ -193,7 +177,7 @@ impl Extension {
     #[cfg(feature = "standalone")]
     async fn setup_standalone(
         session_state: Option<&SessionState>,
-    ) -> datafusion::error::Result<(String, String)> {
+    ) -> datafusion::error::Result<String> {
         use ballista_core::{serde::BallistaCodec, utils::default_config_producer};
 
         let addr = match session_state {
@@ -214,7 +198,7 @@ impl Extension {
 
         let scheduler_url = format!("http://localhost:{}", addr.port());
 
-        let mut scheduler = loop {
+        let scheduler = loop {
             match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
                 Err(_) => {
                     tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
@@ -224,15 +208,6 @@ impl Extension {
             }
         };
 
-        let remote_session_id = scheduler
-            .create_session(CreateSessionParams {
-                settings: config.to_key_value_pairs(),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
-            .into_inner()
-            .session_id;
-
         let concurrent_tasks = config.ballista_standalone_parallelism();
 
         match session_state {
@@ -256,31 +231,6 @@ impl Extension {
             }
         }
 
-        Ok((remote_session_id, scheduler_url))
-    }
-
-    async fn setup_remote(
-        config: &SessionConfig,
-        scheduler_url: String,
-    ) -> datafusion::error::Result<String> {
-        let connection = create_grpc_client_connection(scheduler_url.clone())
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
-
-        let limit = config.ballista_grpc_client_max_message_size();
-        let mut scheduler = SchedulerGrpcClient::new(connection)
-            .max_encoding_message_size(limit)
-            .max_decoding_message_size(limit);
-
-        let remote_session_id = scheduler
-            .create_session(CreateSessionParams {
-                settings: config.to_key_value_pairs(),
-            })
-            .await
-            .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
-            .into_inner()
-            .session_id;
-
-        Ok(remote_session_id)
+        Ok(scheduler_url)
     }
 }
diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml
index ab9b39e4..b05edf3e 100644
--- a/ballista/core/Cargo.toml
+++ b/ballista/core/Cargo.toml
@@ -64,6 +64,7 @@ tokio = { workspace = true }
 tokio-stream = { workspace = true, features = ["net"] }
 tonic = { workspace = true }
 url = { workspace = true }
+uuid = { workspace = true }
 
 [dev-dependencies]
 tempfile = { workspace = true }
diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index 6e1fe131..8acb7751 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -524,15 +524,20 @@ message ExecuteQueryParams {
     string sql = 2 [deprecated=true]; // I'd suggest to remove this, if SQL needed use `flight-sql`
   }
   
-  optional string session_id = 3;
+  string session_id = 3;
   repeated KeyValuePair settings = 4;
+  // operation_id is unique number for each request 
+  // client makes. it helps mapping requests between 
+  // client and scheduler
+  string operation_id = 5;
 }
 
-message CreateSessionParams {
+message CreateUpdateSessionParams {
+  string session_id = 2;
   repeated KeyValuePair settings = 1;
 }
 
-message CreateSessionResult {
+message CreateUpdateSessionResult {
   string session_id = 1;
 }
 
@@ -562,6 +567,7 @@ message ExecuteQueryResult {
     ExecuteQuerySuccessResult success = 1;
     ExecuteQueryFailureResult failure = 2;
   }
+  string operation_id = 3;
 }
 
 message ExecuteQuerySuccessResult {
@@ -697,9 +703,7 @@ service SchedulerGrpc {
 
   rpc UpdateTaskStatus (UpdateTaskStatusParams) returns (UpdateTaskStatusResult) {}
 
-  rpc CreateSession (CreateSessionParams) returns (CreateSessionResult) {}
-
-  rpc UpdateSession (UpdateSessionParams) returns (UpdateSessionResult) {}
+  rpc CreateUpdateSession (CreateUpdateSessionParams) returns (CreateUpdateSessionResult) {}
 
   rpc RemoveSession (RemoveSessionParams) returns (RemoveSessionResult) {}
 
diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs
index 7ccb2727..600042d2 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -43,7 +43,7 @@ use datafusion_proto::logical_plan::{
     AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
 use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
-use log::{error, info};
+use log::{debug, error, info};
 use std::any::Any;
 use std::fmt::Debug;
 use std::marker::PhantomData;
@@ -214,11 +214,16 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
                 },
             )
             .collect();
-
+        let operation_id = uuid::Uuid::now_v7().to_string();
+        debug!(
+            "Distributed query with session_id: {}, execution operation_id: {}",
+            self.session_id, operation_id
+        );
         let query = ExecuteQueryParams {
             query: Some(Query::LogicalPlan(buf)),
             settings,
-            session_id: Some(self.session_id.clone()),
+            session_id: self.session_id.clone(),
+            operation_id,
         };
 
         let metric_row_count = MetricBuilder::new(&self.metrics).output_rows(partition);
diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs
index 74d66713..4f8cde7f 100644
--- a/ballista/core/src/extension.rs
+++ b/ballista/core/src/extension.rs
@@ -38,7 +38,6 @@ pub trait SessionStateExt {
     /// State will be created with appropriate [SessionConfig] configured
     fn new_ballista_state(
         scheduler_url: String,
-        session_id: String,
     ) -> datafusion::error::Result<SessionState>;
     /// Upgrades [SessionState] for ballista usage
     ///
@@ -46,7 +45,6 @@ pub trait SessionStateExt {
     fn upgrade_for_ballista(
         self,
         scheduler_url: String,
-        session_id: String,
     ) -> datafusion::error::Result<SessionState>;
 }
 
@@ -57,6 +55,13 @@ pub trait SessionConfigExt {
     /// ballista configuration initialized
     fn new_with_ballista() -> SessionConfig;
 
+    /// update [SessionConfig] with Ballista specific settings
+    fn upgrade_for_ballista(self) -> SessionConfig;
+
+    /// return ballista specific configuration or
+    /// creates one if does not exist
+    fn ballista_config(&self) -> BallistaConfig;
+
     /// Overrides ballista's [LogicalExtensionCodec]
     fn with_ballista_logical_extension_codec(
         self,
@@ -131,7 +136,6 @@ pub trait SessionConfigHelperExt {
 impl SessionStateExt for SessionState {
     fn new_ballista_state(
         scheduler_url: String,
-        session_id: String,
     ) -> datafusion::error::Result<SessionState> {
         let session_config = SessionConfig::new_with_ballista();
         let planner = BallistaQueryPlanner::<LogicalPlanNode>::new(
@@ -145,7 +149,6 @@ impl SessionStateExt for SessionState {
             .with_config(session_config)
             .with_runtime_env(Arc::new(runtime_env))
             .with_query_planner(Arc::new(planner))
-            .with_session_id(session_id)
             .build();
 
         Ok(session_state)
@@ -154,35 +157,23 @@ impl SessionStateExt for SessionState {
     fn upgrade_for_ballista(
         self,
         scheduler_url: String,
-        session_id: String,
     ) -> datafusion::error::Result<SessionState> {
         let codec_logical = self.config().ballista_logical_extension_codec();
         let planner_override = self.config().ballista_query_planner();
 
-        let new_config = self
-            .config()
-            .options()
-            .extensions
-            .get::<BallistaConfig>()
-            .cloned()
-            .unwrap_or_else(BallistaConfig::default);
+        let session_config = self.config().clone().upgrade_for_ballista();
 
-        let session_config = self
-            .config()
-            .clone()
-            .with_option_extension(new_config.clone())
-            .ballista_restricted_configuration();
+        let ballista_config = session_config.ballista_config();
 
-        let builder = SessionStateBuilder::new_from_existing(self)
-            .with_config(session_config)
-            .with_session_id(session_id);
+        let builder =
+            SessionStateBuilder::new_from_existing(self).with_config(session_config);
 
         let builder = match planner_override {
             Some(planner) => builder.with_query_planner(planner),
             None => {
                 let planner = BallistaQueryPlanner::<LogicalPlanNode>::with_extension(
                     scheduler_url,
-                    new_config,
+                    ballista_config,
                     codec_logical,
                 );
                 builder.with_query_planner(Arc::new(planner))
@@ -201,6 +192,26 @@ impl SessionConfigExt for SessionConfig {
             .with_target_partitions(16)
             .ballista_restricted_configuration()
     }
+
+    fn upgrade_for_ballista(self) -> SessionConfig {
+        // if ballista config is not provided
+        // one is created and session state is updated
+        let ballista_config = self.ballista_config();
+
+        // session config has ballista config extension and
+        // default datafusion configuration is altered
+        // to fit ballista execution
+        self.with_option_extension(ballista_config)
+            .ballista_restricted_configuration()
+    }
+
+    fn ballista_config(&self) -> BallistaConfig {
+        self.options()
+            .extensions
+            .get::<BallistaConfig>()
+            .cloned()
+            .unwrap_or_else(BallistaConfig::default)
+    }
     fn with_ballista_logical_extension_codec(
         self,
         codec: Arc<dyn LogicalExtensionCodec>,
@@ -452,11 +463,8 @@ mod test {
     // Ballista disables round robin repatriations
     #[tokio::test]
     async fn should_disable_round_robin_repartition() {
-        let state = SessionState::new_ballista_state(
-            "scheduler_url".to_string(),
-            "session_id".to_string(),
-        )
-        .unwrap();
+        let state =
+            SessionState::new_ballista_state("scheduler_url".to_string()).unwrap();
 
         assert!(!state.config().round_robin_repartition());
 
@@ -464,7 +472,7 @@ mod test {
 
         assert!(state.config().round_robin_repartition());
         let state = state
-            .upgrade_for_ballista("scheduler_url".to_string(), "session_id".to_string())
+            .upgrade_for_ballista("scheduler_url".to_string())
             .unwrap();
 
         assert!(!state.config().round_robin_repartition());
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index ed73ab40..8809051c 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -786,10 +786,15 @@ pub struct UpdateTaskStatusResult {
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ExecuteQueryParams {
-    #[prost(string, optional, tag = "3")]
-    pub session_id: ::core::option::Option<::prost::alloc::string::String>,
+    #[prost(string, tag = "3")]
+    pub session_id: ::prost::alloc::string::String,
     #[prost(message, repeated, tag = "4")]
     pub settings: ::prost::alloc::vec::Vec<KeyValuePair>,
+    /// operation_id is unique number for each request
+    /// client makes. it helps mapping requests between
+    /// client and scheduler
+    #[prost(string, tag = "5")]
+    pub operation_id: ::prost::alloc::string::String,
     #[prost(oneof = "execute_query_params::Query", tags = "1, 2")]
     pub query: ::core::option::Option<execute_query_params::Query>,
 }
@@ -805,12 +810,14 @@ pub mod execute_query_params {
     }
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
-pub struct CreateSessionParams {
+pub struct CreateUpdateSessionParams {
+    #[prost(string, tag = "2")]
+    pub session_id: ::prost::alloc::string::String,
     #[prost(message, repeated, tag = "1")]
     pub settings: ::prost::alloc::vec::Vec<KeyValuePair>,
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
-pub struct CreateSessionResult {
+pub struct CreateUpdateSessionResult {
     #[prost(string, tag = "1")]
     pub session_id: ::prost::alloc::string::String,
 }
@@ -843,6 +850,8 @@ pub struct ExecuteSqlParams {
 }
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ExecuteQueryResult {
+    #[prost(string, tag = "3")]
+    pub operation_id: ::prost::alloc::string::String,
     #[prost(oneof = "execute_query_result::Result", tags = "1, 2")]
     pub result: ::core::option::Option<execute_query_result::Result>,
 }
@@ -1230,37 +1239,11 @@ pub mod scheduler_grpc_client {
                 );
             self.inner.unary(req, path, codec).await
         }
-        pub async fn create_session(
-            &mut self,
-            request: impl tonic::IntoRequest<super::CreateSessionParams>,
-        ) -> std::result::Result<
-            tonic::Response<super::CreateSessionResult>,
-            tonic::Status,
-        > {
-            self.inner
-                .ready()
-                .await
-                .map_err(|e| {
-                    tonic::Status::unknown(
-                        format!("Service was not ready: {}", e.into()),
-                    )
-                })?;
-            let codec = tonic::codec::ProstCodec::default();
-            let path = http::uri::PathAndQuery::from_static(
-                "/ballista.protobuf.SchedulerGrpc/CreateSession",
-            );
-            let mut req = request.into_request();
-            req.extensions_mut()
-                .insert(
-                    GrpcMethod::new("ballista.protobuf.SchedulerGrpc", "CreateSession"),
-                );
-            self.inner.unary(req, path, codec).await
-        }
-        pub async fn update_session(
+        pub async fn create_update_session(
             &mut self,
-            request: impl tonic::IntoRequest<super::UpdateSessionParams>,
+            request: impl tonic::IntoRequest<super::CreateUpdateSessionParams>,
         ) -> std::result::Result<
-            tonic::Response<super::UpdateSessionResult>,
+            tonic::Response<super::CreateUpdateSessionResult>,
             tonic::Status,
         > {
             self.inner
@@ -1273,12 +1256,15 @@ pub mod scheduler_grpc_client {
                 })?;
             let codec = tonic::codec::ProstCodec::default();
             let path = http::uri::PathAndQuery::from_static(
-                "/ballista.protobuf.SchedulerGrpc/UpdateSession",
+                "/ballista.protobuf.SchedulerGrpc/CreateUpdateSession",
             );
             let mut req = request.into_request();
             req.extensions_mut()
                 .insert(
-                    GrpcMethod::new("ballista.protobuf.SchedulerGrpc", "UpdateSession"),
+                    GrpcMethod::new(
+                        "ballista.protobuf.SchedulerGrpc",
+                        "CreateUpdateSession",
+                    ),
                 );
             self.inner.unary(req, path, codec).await
         }
@@ -1698,18 +1684,11 @@ pub mod scheduler_grpc_server {
             tonic::Response<super::UpdateTaskStatusResult>,
             tonic::Status,
         >;
-        async fn create_session(
+        async fn create_update_session(
             &self,
-            request: tonic::Request<super::CreateSessionParams>,
+            request: tonic::Request<super::CreateUpdateSessionParams>,
         ) -> std::result::Result<
-            tonic::Response<super::CreateSessionResult>,
-            tonic::Status,
-        >;
-        async fn update_session(
-            &self,
-            request: tonic::Request<super::UpdateSessionParams>,
-        ) -> std::result::Result<
-            tonic::Response<super::UpdateSessionResult>,
+            tonic::Response<super::CreateUpdateSessionResult>,
             tonic::Status,
         >;
         async fn remove_session(
@@ -2015,70 +1994,26 @@ pub mod scheduler_grpc_server {
                     };
                     Box::pin(fut)
                 }
-                "/ballista.protobuf.SchedulerGrpc/CreateSession" => {
-                    #[allow(non_camel_case_types)]
-                    struct CreateSessionSvc<T: SchedulerGrpc>(pub Arc<T>);
-                    impl<
-                        T: SchedulerGrpc,
-                    > tonic::server::UnaryService<super::CreateSessionParams>
-                    for CreateSessionSvc<T> {
-                        type Response = super::CreateSessionResult;
-                        type Future = BoxFuture<
-                            tonic::Response<Self::Response>,
-                            tonic::Status,
-                        >;
-                        fn call(
-                            &mut self,
-                            request: tonic::Request<super::CreateSessionParams>,
-                        ) -> Self::Future {
-                            let inner = Arc::clone(&self.0);
-                            let fut = async move {
-                                <T as SchedulerGrpc>::create_session(&inner, request).await
-                            };
-                            Box::pin(fut)
-                        }
-                    }
-                    let accept_compression_encodings = self.accept_compression_encodings;
-                    let send_compression_encodings = self.send_compression_encodings;
-                    let max_decoding_message_size = self.max_decoding_message_size;
-                    let max_encoding_message_size = self.max_encoding_message_size;
-                    let inner = self.inner.clone();
-                    let fut = async move {
-                        let method = CreateSessionSvc(inner);
-                        let codec = tonic::codec::ProstCodec::default();
-                        let mut grpc = tonic::server::Grpc::new(codec)
-                            .apply_compression_config(
-                                accept_compression_encodings,
-                                send_compression_encodings,
-                            )
-                            .apply_max_message_size_config(
-                                max_decoding_message_size,
-                                max_encoding_message_size,
-                            );
-                        let res = grpc.unary(method, req).await;
-                        Ok(res)
-                    };
-                    Box::pin(fut)
-                }
-                "/ballista.protobuf.SchedulerGrpc/UpdateSession" => {
+                "/ballista.protobuf.SchedulerGrpc/CreateUpdateSession" => {
                     #[allow(non_camel_case_types)]
-                    struct UpdateSessionSvc<T: SchedulerGrpc>(pub Arc<T>);
+                    struct CreateUpdateSessionSvc<T: SchedulerGrpc>(pub Arc<T>);
                     impl<
                         T: SchedulerGrpc,
-                    > tonic::server::UnaryService<super::UpdateSessionParams>
-                    for UpdateSessionSvc<T> {
-                        type Response = super::UpdateSessionResult;
+                    > tonic::server::UnaryService<super::CreateUpdateSessionParams>
+                    for CreateUpdateSessionSvc<T> {
+                        type Response = super::CreateUpdateSessionResult;
                         type Future = BoxFuture<
                             tonic::Response<Self::Response>,
                             tonic::Status,
                         >;
                         fn call(
                             &mut self,
-                            request: tonic::Request<super::UpdateSessionParams>,
+                            request: tonic::Request<super::CreateUpdateSessionParams>,
                         ) -> Self::Future {
                             let inner = Arc::clone(&self.0);
                             let fut = async move {
-                                <T as SchedulerGrpc>::update_session(&inner, request).await
+                                <T as SchedulerGrpc>::create_update_session(&inner, request)
+                                    .await
                             };
                             Box::pin(fut)
                         }
@@ -2089,7 +2024,7 @@ pub mod scheduler_grpc_server {
                     let max_encoding_message_size = self.max_encoding_message_size;
                     let inner = self.inner.clone();
                     let fut = async move {
-                        let method = UpdateSessionSvc(inner);
+                        let method = CreateUpdateSessionSvc(inner);
                         let codec = tonic::codec::ProstCodec::default();
                         let mut grpc = tonic::server::Grpc::new(codec)
                             .apply_compression_config(
diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs
index 38c46d02..72adfe08 100644
--- a/ballista/executor/src/standalone.rs
+++ b/ballista/executor/src/standalone.rs
@@ -18,7 +18,6 @@
 use crate::metrics::LoggingMetricsCollector;
 use crate::{execution_loop, executor::Executor, flight_service::BallistaFlightService};
 use arrow_flight::flight_service_server::FlightServiceServer;
-use ballista_core::config::BallistaConfig;
 use ballista_core::extension::SessionConfigExt;
 use ballista_core::registry::BallistaFunctionRegistry;
 use ballista_core::utils::default_config_producer;
@@ -57,12 +56,7 @@ pub async fn new_standalone_executor_from_state(
         datafusion_proto::protobuf::PhysicalPlanNode,
     > = BallistaCodec::new(logical, physical);
 
-    let config = session_state
-        .config()
-        .clone()
-        .with_option_extension(BallistaConfig::default()) // TODO: do we need this statement
-        ;
-
+    let config = session_state.config().clone().upgrade_for_ballista();
     let runtime = session_state.runtime_env().clone();
 
     let config_producer: ConfigProducer = Arc::new(move || config.clone());
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 071a48d8..4ea39dbb 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -346,8 +346,6 @@ pub struct InMemoryJobState {
     queued_jobs: DashMap<String, (String, u64)>,
     /// In-memory store of running job statuses. Map from Job ID -> JobStatus
     running_jobs: DashMap<String, JobStatus>,
-    /// Active ballista sessions
-    sessions: DashMap<String, Arc<SessionContext>>,
     /// `SessionBuilder` for building DataFusion `SessionContext` from `BallistaConfig`
     session_builder: SessionBuilder,
     /// Sender of job events
@@ -367,7 +365,7 @@ impl InMemoryJobState {
             completed_jobs: Default::default(),
             queued_jobs: Default::default(),
             running_jobs: Default::default(),
-            sessions: Default::default(),
+            //sessions: Default::default(),
             session_builder,
             job_event_sender: ClusterEventSender::new(100),
             config_producer,
@@ -460,42 +458,27 @@ impl JobState for InMemoryJobState {
         Ok(())
     }
 
-    async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> {
-        self.sessions
-            .get(session_id)
-            .map(|session_ctx| session_ctx.clone())
-            .ok_or_else(|| {
-                BallistaError::General(format!("No session for {session_id} found"))
-            })
-    }
-
-    async fn create_session(
-        &self,
-        config: &SessionConfig,
-    ) -> Result<Arc<SessionContext>> {
-        let session = create_datafusion_context(config, self.session_builder.clone())?;
-        self.sessions.insert(session.session_id(), session.clone());
-
-        Ok(session)
-    }
-
-    async fn update_session(
+    async fn create_or_update_session(
         &self,
         session_id: &str,
         config: &SessionConfig,
     ) -> Result<Arc<SessionContext>> {
-        let session = create_datafusion_context(config, self.session_builder.clone())?;
-        self.sessions
-            .insert(session_id.to_string(), session.clone());
+        self.job_event_sender.send(&JobStateEvent::SessionAccessed {
+            session_id: session_id.to_string(),
+        });
 
-        Ok(session)
+        Ok(create_datafusion_context(
+            config,
+            self.session_builder.clone(),
+        )?)
     }
 
-    async fn remove_session(
-        &self,
-        session_id: &str,
-    ) -> Result<Option<Arc<SessionContext>>> {
-        Ok(self.sessions.remove(session_id).map(|(_key, value)| value))
+    async fn remove_session(&self, session_id: &str) -> Result<()> {
+        self.job_event_sender.send(&JobStateEvent::SessionRemoved {
+            session_id: session_id.to_string(),
+        });
+
+        Ok(())
     }
 
     async fn job_state_events(&self) -> Result<JobStateEventStream> {
@@ -574,6 +557,7 @@ mod test {
     use ballista_core::serde::protobuf::JobStatus;
     use ballista_core::serde::scheduler::{ExecutorMetadata, ExecutorSpecification};
     use ballista_core::utils::{default_config_producer, default_session_builder};
+    use datafusion::prelude::SessionConfig;
     use futures::StreamExt;
     use tokio::sync::Barrier;
 
@@ -728,4 +712,43 @@ mod test {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_in_memory_session_notification() -> Result<()> {
+        let state = InMemoryJobState::new(
+            "",
+            Arc::new(default_session_builder),
+            Arc::new(default_config_producer),
+        );
+
+        let event_stream = state.job_state_events().await?;
+
+        state
+            .create_or_update_session("session_id_0", &SessionConfig::new())
+            .await?;
+
+        state.remove_session("session_id_0").await?;
+
+        state
+            .create_or_update_session("session_id_1", &SessionConfig::new())
+            .await?;
+
+        let result = event_stream.take(3).collect::<Vec<JobStateEvent>>().await;
+        assert_eq!(3, result.len());
+        let expected = vec![
+            JobStateEvent::SessionAccessed {
+                session_id: "session_id_0".to_string(),
+            },
+            JobStateEvent::SessionRemoved {
+                session_id: "session_id_0".to_string(),
+            },
+            JobStateEvent::SessionAccessed {
+                session_id: "session_id_1".to_string(),
+            },
+        ];
+
+        assert_eq!(expected, result);
+
+        Ok(())
+    }
 }
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 7319b5c8..2b1a2f2b 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -30,7 +30,6 @@ use datafusion::prelude::{SessionConfig, SessionContext};
 use futures::Stream;
 use log::debug;
 
-use ballista_core::config::BallistaConfig;
 use ballista_core::consistent_hash::ConsistentHash;
 use ballista_core::error::Result;
 use ballista_core::serde::protobuf::{
@@ -232,15 +231,9 @@ pub enum JobStateEvent {
         job_id: String,
     },
     /// Event when a new session has been created
-    SessionCreated {
-        session_id: String,
-        config: BallistaConfig,
-    },
-    /// Event when a session configuration has been updated
-    SessionUpdated {
-        session_id: String,
-        config: BallistaConfig,
-    },
+    SessionAccessed { session_id: String },
+    /// Event when a session configuration has been removed
+    SessionRemoved { session_id: String },
 }
 
 /// Events related to the state of cluster.
@@ -308,25 +301,14 @@ pub trait JobState: Send + Sync {
     /// of a job changes in state
     async fn job_state_events(&self) -> Result<JobStateEventStream>;
 
-    /// Get the `SessionContext` associated with `session_id`. Returns an error if the
-    /// session does not exist
-    async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>>;
-
-    /// Create a new saved session
-    async fn create_session(&self, config: &SessionConfig)
-        -> Result<Arc<SessionContext>>;
-
-    /// Update a new saved session. If the session does not exist, a new one will be created
-    async fn update_session(
+    /// Create new session or update existing one
+    async fn create_or_update_session(
         &self,
         session_id: &str,
         config: &SessionConfig,
     ) -> Result<Arc<SessionContext>>;
 
-    async fn remove_session(
-        &self,
-        session_id: &str,
-    ) -> Result<Option<Arc<SessionContext>>>;
+    async fn remove_session(&self, session_id: &str) -> Result<()>;
 
     // TODO MM not sure this is the best place to put config producer
     fn produce_config(&self) -> SessionConfig;
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs
index 1c3c922b..7322304a 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -23,13 +23,12 @@ use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
 use ballista_core::serde::protobuf::{
     execute_query_failure_result, execute_query_result, AvailableTaskSlots,
     CancelJobParams, CancelJobResult, CleanJobDataParams, CleanJobDataResult,
-    CreateSessionParams, CreateSessionResult, ExecuteQueryFailureResult,
+    CreateUpdateSessionParams, CreateUpdateSessionResult, ExecuteQueryFailureResult,
     ExecuteQueryParams, ExecuteQueryResult, ExecuteQuerySuccessResult, ExecutorHeartbeat,
     ExecutorStoppedParams, ExecutorStoppedResult, GetJobStatusParams, GetJobStatusResult,
     HeartBeatParams, HeartBeatResult, PollWorkParams, PollWorkResult,
     RegisterExecutorParams, RegisterExecutorResult, RemoveSessionParams,
-    RemoveSessionResult, UpdateSessionParams, UpdateSessionResult,
-    UpdateTaskStatusParams, UpdateTaskStatusResult,
+    RemoveSessionResult, UpdateTaskStatusParams, UpdateTaskStatusResult,
 };
 use ballista_core::serde::scheduler::ExecutorMetadata;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -272,51 +271,27 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         Ok(Response::new(UpdateTaskStatusResult { success: true }))
     }
 
-    async fn create_session(
+    async fn create_update_session(
         &self,
-        request: Request<CreateSessionParams>,
-    ) -> Result<Response<CreateSessionResult>, Status> {
+        request: Request<CreateUpdateSessionParams>,
+    ) -> Result<Response<CreateUpdateSessionResult>, Status> {
         let session_params = request.into_inner();
 
         let session_config = self.state.session_manager.produce_config();
         let session_config =
             session_config.update_from_key_value_pair(&session_params.settings);
 
-        let ctx = self
+        let _ = self
             .state
             .session_manager
-            .create_session(&session_config)
-            .await
-            .map_err(|e| {
-                Status::internal(format!("Failed to create SessionContext: {e:?}"))
-            })?;
+            .create_or_update_session(&session_params.session_id, &session_config)
+            .await;
 
-        Ok(Response::new(CreateSessionResult {
-            session_id: ctx.session_id(),
+        Ok(Response::new(CreateUpdateSessionResult {
+            session_id: session_params.session_id,
         }))
     }
 
-    async fn update_session(
-        &self,
-        request: Request<UpdateSessionParams>,
-    ) -> Result<Response<UpdateSessionResult>, Status> {
-        let session_params = request.into_inner();
-
-        let session_config = self.state.session_manager.produce_config();
-        let session_config =
-            session_config.update_from_key_value_pair(&session_params.settings);
-
-        self.state
-            .session_manager
-            .update_session(&session_params.session_id, &session_config)
-            .await
-            .map_err(|e| {
-                Status::internal(format!("Failed to create SessionContext: {e:?}"))
-            })?;
-
-        Ok(Response::new(UpdateSessionResult { success: true }))
-    }
-
     async fn remove_session(
         &self,
         request: Request<RemoveSessionParams>,
@@ -344,6 +319,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         if let ExecuteQueryParams {
             query: Some(query),
             session_id,
+            operation_id,
             settings,
         } = query_params
         {
@@ -353,58 +329,26 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                 .and_then(|s| s.value.clone())
                 .unwrap_or_default();
 
-            let (session_id, session_ctx) = match session_id {
-                Some(session_id) => {
-                    match self.state.session_manager.get_session(&session_id).await {
-                        Ok(ctx) => {
-                            // Update [SessionConfig] using received properties
-
-                            // TODO MM can we do something better here?
-                            // move this to update session and use .update_session(&session_params.session_id, &session_config)
-                            // instead of get_session.
-                            //
-                            // also we should consider sending properties if/when changed rather than
-                            // all properties every time
-
-                            let state = ctx.state_ref();
-                            let mut state = state.write();
-                            let config = state.config_mut();
-                            config.update_from_key_value_pair_mut(&settings);
-
-                            (session_id, ctx)
-                        }
-                        Err(e) => {
-                            let msg = format!("Failed to load SessionContext for session ID {session_id}: {e}");
-                            error!("{}", msg);
-                            return Ok(Response::new(ExecuteQueryResult {
-                                result: Some(execute_query_result::Result::Failure(
-                                    ExecuteQueryFailureResult {
-                                        failure: Some(execute_query_failure_result::Failure::SessionNotFound(msg)),
-                                    },
-                                )),
-                            }));
-                        }
-                    }
-                }
-                _ => {
-                    // Create default config
-                    let session_config = self.state.session_manager.produce_config();
-                    let session_config =
-                        session_config.update_from_key_value_pair(&settings);
-
-                    let ctx = self
-                        .state
-                        .session_manager
-                        .create_session(&session_config)
-                        .await
-                        .map_err(|e| {
-                            Status::internal(format!(
-                                "Failed to create SessionContext: {e:?}"
-                            ))
-                        })?;
+            let job_id = self.state.task_manager.generate_job_id();
 
-                    (ctx.session_id(), ctx)
-                }
+            info!("execution query - session_id: {}, operation_id: {}, job_name: {}, job_id: {}", session_id, operation_id, job_name, job_id);
+
+            let (session_id, session_ctx) = {
+                let session_config = self.state.session_manager.produce_config();
+                let session_config = session_config.update_from_key_value_pair(&settings);
+
+                let ctx = self
+                    .state
+                    .session_manager
+                    .create_or_update_session(&session_id, &session_config)
+                    .await
+                    .map_err(|e| {
+                        Status::internal(format!(
+                            "Failed to create SessionContext: {e:?}"
+                        ))
+                    })?;
+
+                (session_id, ctx)
             };
 
             let plan = match query {
@@ -421,6 +365,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                                 format!("Could not parse logical plan protobuf: {e}");
                             error!("{}", msg);
                             return Ok(Response::new(ExecuteQueryResult {
+                                operation_id,
                                 result: Some(execute_query_result::Result::Failure(
                                     ExecuteQueryFailureResult {
                                         failure: Some(execute_query_failure_result::Failure::PlanParsingFailure(msg)),
@@ -441,6 +386,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                             let msg = format!("Error parsing SQL: {e}");
                             error!("{}", msg);
                             return Ok(Response::new(ExecuteQueryResult {
+                                operation_id,
                                 result: Some(execute_query_result::Result::Failure(
                                     ExecuteQueryFailureResult {
                                         failure: Some(execute_query_failure_result::Failure::PlanParsingFailure(msg)),
@@ -470,6 +416,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                 })?;
 
             Ok(Response::new(ExecuteQueryResult {
+                operation_id,
                 result: Some(execute_query_result::Result::Success(
                     ExecuteQuerySuccessResult { job_id, session_id },
                 )),
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs
index e732145d..d19abd2d 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -420,7 +420,7 @@ mod test {
         let ctx = scheduler
             .state
             .session_manager
-            .create_session(&config)
+            .create_or_update_session("session_id", &config)
             .await?;
 
         let job_id = "job";
diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs
index f7d12152..329f32a5 100644
--- a/ballista/scheduler/src/standalone.rs
+++ b/ballista/scheduler/src/standalone.rs
@@ -29,7 +29,7 @@ use ballista_core::{
     error::Result, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer,
     BALLISTA_VERSION,
 };
-use datafusion::execution::{SessionState, SessionStateBuilder};
+use datafusion::execution::SessionState;
 use datafusion::prelude::SessionConfig;
 use datafusion_proto::protobuf::LogicalPlanNode;
 use datafusion_proto::protobuf::PhysicalPlanNode;
@@ -56,14 +56,7 @@ pub async fn new_standalone_scheduler_from_state(
     let codec = BallistaCodec::new(logical, physical);
     let session_config = session_state.config().clone();
     let session_state = session_state.clone();
-    let session_builder = Arc::new(move |c: SessionConfig| {
-        Ok(
-            SessionStateBuilder::new_from_existing(session_state.clone())
-                .with_config(c)
-                .build(),
-        )
-    });
-
+    let session_builder = Arc::new(move |_: SessionConfig| Ok(session_state.clone()));
     let config_producer = Arc::new(move || session_config.clone());
 
     new_standalone_scheduler_with_builder(session_builder, config_producer, codec).await
diff --git a/ballista/scheduler/src/state/session_manager.rs b/ballista/scheduler/src/state/session_manager.rs
index 59813167..7538aff6 100644
--- a/ballista/scheduler/src/state/session_manager.rs
+++ b/ballista/scheduler/src/state/session_manager.rs
@@ -31,31 +31,18 @@ impl SessionManager {
     pub fn new(state: Arc<dyn JobState>) -> Self {
         Self { state }
     }
-
-    pub async fn remove_session(
-        &self,
-        session_id: &str,
-    ) -> Result<Option<Arc<SessionContext>>> {
+    pub async fn remove_session(&self, session_id: &str) -> Result<()> {
         self.state.remove_session(session_id).await
     }
 
-    pub async fn update_session(
+    pub async fn create_or_update_session(
         &self,
         session_id: &str,
         config: &SessionConfig,
     ) -> Result<Arc<SessionContext>> {
-        self.state.update_session(session_id, config).await
-    }
-
-    pub async fn create_session(
-        &self,
-        config: &SessionConfig,
-    ) -> Result<Arc<SessionContext>> {
-        self.state.create_session(config).await
-    }
-
-    pub async fn get_session(&self, session_id: &str) -> Result<Arc<SessionContext>> {
-        self.state.get_session(session_id).await
+        self.state
+            .create_or_update_session(session_id, config)
+            .await
     }
 
     pub(crate) fn produce_config(&self) -> SessionConfig {
diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs
index c97fe1d3..45d70cb6 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -476,7 +476,7 @@ impl SchedulerTest {
         self.scheduler
             .state
             .session_manager
-            .create_session(&self.session_config)
+            .create_or_update_session("session_id", &self.session_config)
             .await
     }
 
@@ -486,7 +486,7 @@ impl SchedulerTest {
             .scheduler
             .state
             .session_manager
-            .create_session(&self.session_config)
+            .create_or_update_session("session_id", &self.session_config)
             .await?;
 
         let job_id = self.scheduler.submit_job(job_name, ctx, plan).await?;
@@ -609,7 +609,7 @@ impl SchedulerTest {
             .scheduler
             .state
             .session_manager
-            .create_session(&self.session_config)
+            .create_or_update_session("session_id", &self.session_config)
             .await?;
 
         let job_id = self.scheduler.submit_job(job_name, ctx, plan).await?;
diff --git a/examples/tests/common/mod.rs b/examples/tests/common/mod.rs
index 1e8091ed..710829e3 100644
--- a/examples/tests/common/mod.rs
+++ b/examples/tests/common/mod.rs
@@ -168,10 +168,7 @@ fn init() {
     // Enable RUST_LOG logging configuration for test
     let _ = env_logger::builder()
         .filter_level(log::LevelFilter::Info)
-        .parse_filters(
-            "ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug",
-        )
-        //.parse_filters("ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug,datafusion=debug")
+        .parse_filters("ballista=debug,ballista_scheduler=debug,ballista_executor=debug")
         .is_test(true)
         .try_init();
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org

Reply via email to