This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 01bcf7c4 Refine the ExecuteQuery grpc interface (#790)
01bcf7c4 is described below

commit 01bcf7c43e0e8ac729099d83fcc50017c276ebc4
Author: yahoNanJing <[email protected]>
AuthorDate: Tue May 30 10:02:40 2023 +0800

    Refine the ExecuteQuery grpc interface (#790)
    
    * Decouple session related operations from the execute_query interface
    
    * Add failure details for grpc interface execute_query
    
    ---------
    
    Co-authored-by: yangzhong <[email protected]>
---
 ballista/client/src/context.rs                     |  10 +-
 ballista/core/proto/ballista.proto                 |  46 +++
 .../core/src/execution_plans/distributed_query.rs  |  25 +-
 ballista/core/src/serde/generated/ballista.rs      | 314 +++++++++++++++++++++
 ballista/scheduler/src/cluster/kv.rs               |  11 +
 ballista/scheduler/src/cluster/memory.rs           |   7 +
 ballista/scheduler/src/cluster/mod.rs              |   5 +
 ballista/scheduler/src/scheduler_server/grpc.rs    | 222 ++++++++++-----
 ballista/scheduler/src/state/session_manager.rs    |   7 +
 9 files changed, 555 insertions(+), 92 deletions(-)

diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
index 99b7d9b1..94171a6d 100644
--- a/ballista/client/src/context.rs
+++ b/ballista/client/src/context.rs
@@ -28,7 +28,7 @@ use std::sync::Arc;
 
 use ballista_core::config::BallistaConfig;
 use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
-use ballista_core::serde::protobuf::{ExecuteQueryParams, KeyValuePair};
+use ballista_core::serde::protobuf::{CreateSessionParams, KeyValuePair};
 use ballista_core::utils::{
     create_df_ctx_with_ballista_query_planner, create_grpc_client_connection,
 };
@@ -103,8 +103,7 @@ impl BallistaContext {
         let mut scheduler = SchedulerGrpcClient::new(connection);
 
         let remote_session_id = scheduler
-            .execute_query(ExecuteQueryParams {
-                query: None,
+            .create_session(CreateSessionParams {
                 settings: config
                     .settings()
                     .iter()
@@ -113,7 +112,6 @@ impl BallistaContext {
                         value: v.to_owned(),
                     })
                     .collect::<Vec<_>>(),
-                optional_session_id: None,
             })
             .await
             .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
@@ -162,8 +160,7 @@ impl BallistaContext {
         };
 
         let remote_session_id = scheduler
-            .execute_query(ExecuteQueryParams {
-                query: None,
+            .create_session(CreateSessionParams {
                 settings: config
                     .settings()
                     .iter()
@@ -172,7 +169,6 @@ impl BallistaContext {
                         value: v.to_owned(),
                     })
                     .collect::<Vec<_>>(),
-                optional_session_id: None,
             })
             .await
             .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index dea64227..e596c1a7 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -536,15 +536,55 @@ message ExecuteQueryParams {
   repeated KeyValuePair settings = 4;
 }
 
+// Request message of the CreateSession RPC.
+message CreateSessionParams {
+  // Key/value settings supplied by the client for the session being created.
+  repeated KeyValuePair settings = 1;
+}
+
+// Response message of the CreateSession RPC.
+message CreateSessionResult {
+  // Id of the session reported back by the scheduler.
+  string session_id = 1;
+}
+
+// Request message of the UpdateSession RPC.
+message UpdateSessionParams {
+  // Id of the session whose settings are being updated.
+  string session_id = 1;
+  // Key/value settings sent with the update request.
+  repeated KeyValuePair settings = 2;
+}
+
+// Response message of the UpdateSession RPC.
+message UpdateSessionResult {
+  // Outcome flag for the update request.
+  bool success = 1;
+}
+
+// Request message of the RemoveSession RPC.
+message RemoveSessionParams {
+  // Id of the session to remove.
+  string session_id = 1;
+}
+
+// Response message of the RemoveSession RPC.
+message RemoveSessionResult {
+  // Outcome flag for the removal request.
+  bool success = 1;
+}
+
 message ExecuteSqlParams {
   string sql = 1;
 }
 
+// Response message of the ExecuteQuery RPC: carries either a success payload or
+// a failure description.
 message ExecuteQueryResult {
+  oneof result {
+    // Set when the query was accepted; holds the job and session ids.
+    ExecuteQuerySuccessResult success = 1;
+    // Set when the query could not be executed; holds the failure reason.
+    ExecuteQueryFailureResult failure = 2;
+  }
+}
+
+// Success payload of ExecuteQueryResult.
+message ExecuteQuerySuccessResult {
   // Id of the job associated with the submitted query.
   string job_id = 1;
   // Id of the session the query ran under.
   string session_id = 2;
 }
 
+// Failure payload of ExecuteQueryResult; at most one of the reasons below is
+// set, each carrying a human-readable message.
+message ExecuteQueryFailureResult {
+  oneof failure {
+    // The session id given in the request could not be loaded by the scheduler.
+    string session_not_found = 1;
+    // The logical plan protobuf in the request could not be parsed.
+    string plan_parsing_failure = 2;
+    // Presumably set when the SQL text in the request cannot be parsed --
+    // TODO confirm against the scheduler handler.
+    string sql_parsing_failure = 3;
+  }
+}
+
 message GetJobStatusParams {
   string job_id = 1;
 }
@@ -676,6 +716,12 @@ service SchedulerGrpc {
 
   rpc GetFileMetadata (GetFileMetadataParams) returns (GetFileMetadataResult) 
{}
 
+  rpc CreateSession (CreateSessionParams) returns (CreateSessionResult) {}
+
+  rpc UpdateSession (UpdateSessionParams) returns (UpdateSessionResult) {}
+
+  rpc RemoveSession (RemoveSessionParams) returns (RemoveSessionResult) {}
+
   rpc ExecuteQuery (ExecuteQueryParams) returns (ExecuteQueryResult) {}
 
   rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {}
diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs
index ac461e14..a7d42b82 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -19,9 +19,9 @@ use crate::client::BallistaClient;
 use crate::config::BallistaConfig;
 use crate::serde::protobuf::execute_query_params::OptionalSessionId;
 use crate::serde::protobuf::{
-    execute_query_params::Query, job_status, 
scheduler_grpc_client::SchedulerGrpcClient,
-    ExecuteQueryParams, GetJobStatusParams, GetJobStatusResult, KeyValuePair,
-    PartitionLocation,
+    execute_query_params::Query, execute_query_result, job_status,
+    scheduler_grpc_client::SchedulerGrpcClient, ExecuteQueryParams, 
GetJobStatusParams,
+    GetJobStatusResult, PartitionLocation,
 };
 use crate::utils::create_grpc_client_connection;
 use datafusion::arrow::datatypes::SchemaRef;
@@ -175,15 +175,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for 
DistributedQueryExec<T> {
 
         let query = ExecuteQueryParams {
             query: Some(Query::LogicalPlan(buf)),
-            settings: self
-                .config
-                .settings()
-                .iter()
-                .map(|(k, v)| KeyValuePair {
-                    key: k.to_owned(),
-                    value: v.to_owned(),
-                })
-                .collect::<Vec<_>>(),
+            settings: vec![],
             optional_session_id: Some(OptionalSessionId::SessionId(
                 self.session_id.clone(),
             )),
@@ -242,6 +234,15 @@ async fn execute_query(
         .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
         .into_inner();
 
+    // Unpack the oneof reply. A missing `result` (for example from a scheduler
+    // that does not populate the field) is reported as an execution error
+    // instead of panicking the client.
+    let query_result = match query_result.result {
+        Some(execute_query_result::Result::Success(success_result)) => success_result,
+        Some(execute_query_result::Result::Failure(failure_result)) => {
+            return Err(DataFusionError::Execution(format!(
+                "Fail to execute query due to {failure_result:?}"
+            )));
+        }
+        None => {
+            return Err(DataFusionError::Execution(
+                "Fail to execute query due to empty result from scheduler".to_owned(),
+            ));
+        }
+    };
+
     assert_eq!(
         session_id, query_result.session_id,
         "Session id inconsistent between Client and Server side in 
DistributedQueryExec."
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index 9000fc7e..d0c29e0b 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -916,6 +916,44 @@ pub mod execute_query_params {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct CreateSessionParams {
+    #[prost(message, repeated, tag = "1")]
+    pub settings: ::prost::alloc::vec::Vec<KeyValuePair>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct CreateSessionResult {
+    #[prost(string, tag = "1")]
+    pub session_id: ::prost::alloc::string::String,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct UpdateSessionParams {
+    #[prost(string, tag = "1")]
+    pub session_id: ::prost::alloc::string::String,
+    #[prost(message, repeated, tag = "2")]
+    pub settings: ::prost::alloc::vec::Vec<KeyValuePair>,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct UpdateSessionResult {
+    #[prost(bool, tag = "1")]
+    pub success: bool,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct RemoveSessionParams {
+    #[prost(string, tag = "1")]
+    pub session_id: ::prost::alloc::string::String,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct RemoveSessionResult {
+    #[prost(bool, tag = "1")]
+    pub success: bool,
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ExecuteSqlParams {
     #[prost(string, tag = "1")]
     pub sql: ::prost::alloc::string::String,
@@ -923,6 +961,23 @@ pub struct ExecuteSqlParams {
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
 pub struct ExecuteQueryResult {
+    #[prost(oneof = "execute_query_result::Result", tags = "1, 2")]
+    pub result: ::core::option::Option<execute_query_result::Result>,
+}
+/// Nested message and enum types in `ExecuteQueryResult`.
+pub mod execute_query_result {
+    #[allow(clippy::derive_partial_eq_without_eq)]
+    #[derive(Clone, PartialEq, ::prost::Oneof)]
+    pub enum Result {
+        #[prost(message, tag = "1")]
+        Success(super::ExecuteQuerySuccessResult),
+        #[prost(message, tag = "2")]
+        Failure(super::ExecuteQueryFailureResult),
+    }
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ExecuteQuerySuccessResult {
     #[prost(string, tag = "1")]
     pub job_id: ::prost::alloc::string::String,
     #[prost(string, tag = "2")]
@@ -930,6 +985,25 @@ pub struct ExecuteQueryResult {
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
+pub struct ExecuteQueryFailureResult {
+    #[prost(oneof = "execute_query_failure_result::Failure", tags = "1, 2, 3")]
+    pub failure: ::core::option::Option<execute_query_failure_result::Failure>,
+}
+/// Nested message and enum types in `ExecuteQueryFailureResult`.
+pub mod execute_query_failure_result {
+    #[allow(clippy::derive_partial_eq_without_eq)]
+    #[derive(Clone, PartialEq, ::prost::Oneof)]
+    pub enum Failure {
+        #[prost(string, tag = "1")]
+        SessionNotFound(::prost::alloc::string::String),
+        #[prost(string, tag = "2")]
+        PlanParsingFailure(::prost::alloc::string::String),
+        #[prost(string, tag = "3")]
+        SqlParsingFailure(::prost::alloc::string::String),
+    }
+}
+#[allow(clippy::derive_partial_eq_without_eq)]
+#[derive(Clone, PartialEq, ::prost::Message)]
 pub struct GetJobStatusParams {
     #[prost(string, tag = "1")]
     pub job_id: ::prost::alloc::string::String,
@@ -1339,6 +1413,87 @@ pub mod scheduler_grpc_client {
                 );
             self.inner.unary(req, path, codec).await
         }
+        pub async fn create_session(
+            &mut self,
+            request: impl tonic::IntoRequest<super::CreateSessionParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::CreateSessionResult>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/ballista.protobuf.SchedulerGrpc/CreateSession",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new("ballista.protobuf.SchedulerGrpc", 
"CreateSession"),
+                );
+            self.inner.unary(req, path, codec).await
+        }
+        pub async fn update_session(
+            &mut self,
+            request: impl tonic::IntoRequest<super::UpdateSessionParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::UpdateSessionResult>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/ballista.protobuf.SchedulerGrpc/UpdateSession",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new("ballista.protobuf.SchedulerGrpc", 
"UpdateSession"),
+                );
+            self.inner.unary(req, path, codec).await
+        }
+        pub async fn remove_session(
+            &mut self,
+            request: impl tonic::IntoRequest<super::RemoveSessionParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::RemoveSessionResult>,
+            tonic::Status,
+        > {
+            self.inner
+                .ready()
+                .await
+                .map_err(|e| {
+                    tonic::Status::new(
+                        tonic::Code::Unknown,
+                        format!("Service was not ready: {}", e.into()),
+                    )
+                })?;
+            let codec = tonic::codec::ProstCodec::default();
+            let path = http::uri::PathAndQuery::from_static(
+                "/ballista.protobuf.SchedulerGrpc/RemoveSession",
+            );
+            let mut req = request.into_request();
+            req.extensions_mut()
+                .insert(
+                    GrpcMethod::new("ballista.protobuf.SchedulerGrpc", 
"RemoveSession"),
+                );
+            self.inner.unary(req, path, codec).await
+        }
         pub async fn execute_query(
             &mut self,
             request: impl tonic::IntoRequest<super::ExecuteQueryParams>,
@@ -1734,6 +1889,27 @@ pub mod scheduler_grpc_server {
             tonic::Response<super::GetFileMetadataResult>,
             tonic::Status,
         >;
+        async fn create_session(
+            &self,
+            request: tonic::Request<super::CreateSessionParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::CreateSessionResult>,
+            tonic::Status,
+        >;
+        async fn update_session(
+            &self,
+            request: tonic::Request<super::UpdateSessionParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::UpdateSessionResult>,
+            tonic::Status,
+        >;
+        async fn remove_session(
+            &self,
+            request: tonic::Request<super::RemoveSessionParams>,
+        ) -> std::result::Result<
+            tonic::Response<super::RemoveSessionResult>,
+            tonic::Status,
+        >;
         async fn execute_query(
             &self,
             request: tonic::Request<super::ExecuteQueryParams>,
@@ -2075,6 +2251,144 @@ pub mod scheduler_grpc_server {
                     };
                     Box::pin(fut)
                 }
+                "/ballista.protobuf.SchedulerGrpc/CreateSession" => {
+                    #[allow(non_camel_case_types)]
+                    struct CreateSessionSvc<T: SchedulerGrpc>(pub Arc<T>);
+                    impl<
+                        T: SchedulerGrpc,
+                    > tonic::server::UnaryService<super::CreateSessionParams>
+                    for CreateSessionSvc<T> {
+                        type Response = super::CreateSessionResult;
+                        type Future = BoxFuture<
+                            tonic::Response<Self::Response>,
+                            tonic::Status,
+                        >;
+                        fn call(
+                            &mut self,
+                            request: 
tonic::Request<super::CreateSessionParams>,
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                (*inner).create_session(request).await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = 
self.accept_compression_encodings;
+                    let send_compression_encodings = 
self.send_compression_encodings;
+                    let max_decoding_message_size = 
self.max_decoding_message_size;
+                    let max_encoding_message_size = 
self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let inner = inner.0;
+                        let method = CreateSessionSvc(inner);
+                        let codec = tonic::codec::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.unary(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
+                "/ballista.protobuf.SchedulerGrpc/UpdateSession" => {
+                    #[allow(non_camel_case_types)]
+                    struct UpdateSessionSvc<T: SchedulerGrpc>(pub Arc<T>);
+                    impl<
+                        T: SchedulerGrpc,
+                    > tonic::server::UnaryService<super::UpdateSessionParams>
+                    for UpdateSessionSvc<T> {
+                        type Response = super::UpdateSessionResult;
+                        type Future = BoxFuture<
+                            tonic::Response<Self::Response>,
+                            tonic::Status,
+                        >;
+                        fn call(
+                            &mut self,
+                            request: 
tonic::Request<super::UpdateSessionParams>,
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                (*inner).update_session(request).await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = 
self.accept_compression_encodings;
+                    let send_compression_encodings = 
self.send_compression_encodings;
+                    let max_decoding_message_size = 
self.max_decoding_message_size;
+                    let max_encoding_message_size = 
self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let inner = inner.0;
+                        let method = UpdateSessionSvc(inner);
+                        let codec = tonic::codec::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.unary(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
+                "/ballista.protobuf.SchedulerGrpc/RemoveSession" => {
+                    #[allow(non_camel_case_types)]
+                    struct RemoveSessionSvc<T: SchedulerGrpc>(pub Arc<T>);
+                    impl<
+                        T: SchedulerGrpc,
+                    > tonic::server::UnaryService<super::RemoveSessionParams>
+                    for RemoveSessionSvc<T> {
+                        type Response = super::RemoveSessionResult;
+                        type Future = BoxFuture<
+                            tonic::Response<Self::Response>,
+                            tonic::Status,
+                        >;
+                        fn call(
+                            &mut self,
+                            request: 
tonic::Request<super::RemoveSessionParams>,
+                        ) -> Self::Future {
+                            let inner = Arc::clone(&self.0);
+                            let fut = async move {
+                                (*inner).remove_session(request).await
+                            };
+                            Box::pin(fut)
+                        }
+                    }
+                    let accept_compression_encodings = 
self.accept_compression_encodings;
+                    let send_compression_encodings = 
self.send_compression_encodings;
+                    let max_decoding_message_size = 
self.max_decoding_message_size;
+                    let max_encoding_message_size = 
self.max_encoding_message_size;
+                    let inner = self.inner.clone();
+                    let fut = async move {
+                        let inner = inner.0;
+                        let method = RemoveSessionSvc(inner);
+                        let codec = tonic::codec::ProstCodec::default();
+                        let mut grpc = tonic::server::Grpc::new(codec)
+                            .apply_compression_config(
+                                accept_compression_encodings,
+                                send_compression_encodings,
+                            )
+                            .apply_max_message_size_config(
+                                max_decoding_message_size,
+                                max_encoding_message_size,
+                            );
+                        let res = grpc.unary(method, req).await;
+                        Ok(res)
+                    };
+                    Box::pin(fut)
+                }
                 "/ballista.protobuf.SchedulerGrpc/ExecuteQuery" => {
                     #[allow(non_camel_case_types)]
                     struct ExecuteQuerySvc<T: SchedulerGrpc>(pub Arc<T>);
diff --git a/ballista/scheduler/src/cluster/kv.rs b/ballista/scheduler/src/cluster/kv.rs
index eb164753..effbc0cc 100644
--- a/ballista/scheduler/src/cluster/kv.rs
+++ b/ballista/scheduler/src/cluster/kv.rs
@@ -753,6 +753,17 @@ impl<S: KeyValueStore, T: 'static + AsLogicalPlan, U: 
'static + AsExecutionPlan>
 
         Ok(create_datafusion_context(config, self.session_builder))
     }
+
+    /// Deletes the entry for `session_id` from the `Keyspace::Sessions` keyspace
+    /// and returns the session context that was loaded for it beforehand.
+    ///
+    /// Returns `Ok(None)` when `get_session` fails for this id. An error from the
+    /// store's `delete` call is propagated to the caller.
+    async fn remove_session(
+        &self,
+        session_id: &str,
+    ) -> Result<Option<Arc<SessionContext>>> {
+        // Load the session before deleting it; `.ok()` turns a lookup failure
+        // into `None` rather than aborting the removal.
+        let session_ctx = self.get_session(session_id).await.ok();
+
+        self.store.delete(Keyspace::Sessions, session_id).await?;
+
+        Ok(session_ctx)
+    }
 }
 
 async fn with_lock<Out, F: Future<Output = Out>>(mut lock: Box<dyn Lock>, op: 
F) -> Out {
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index 6d7d7bc6..9a8a6fc9 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -407,6 +407,13 @@ impl JobState for InMemoryJobState {
         Ok(session)
     }
 
+    /// Removes `session_id` from `self.sessions` and returns the session context
+    /// that was stored under it, if the removal yielded an entry.
+    async fn remove_session(
+        &self,
+        session_id: &str,
+    ) -> Result<Option<Arc<SessionContext>>> {
+        // `remove` yields the (key, value) pair; only the value is handed back.
+        Ok(self.sessions.remove(session_id).map(|(_key, value)| value))
+    }
+
     async fn job_state_events(&self) -> Result<JobStateEventStream> {
         Ok(Box::pin(self.job_event_sender.subscribe()))
     }
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 8bea6d77..0d8ae900 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -369,6 +369,11 @@ pub trait JobState: Send + Sync {
         session_id: &str,
         config: &BallistaConfig,
     ) -> Result<Arc<SessionContext>>;
+
+    /// Removes the session identified by `session_id`.
+    ///
+    /// Returns the removed `SessionContext` when the implementation has one to
+    /// hand back, and `None` otherwise.
+    async fn remove_session(
+        &self,
+        session_id: &str,
+    ) -> Result<Option<Arc<SessionContext>>>;
 }
 
 pub(crate) fn reserve_slots_bias(
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs
index 28d3f9cc..485afa8c 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -17,16 +17,20 @@
 
 use ballista_core::config::{BallistaConfig, BALLISTA_JOB_NAME};
 use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, 
Query};
+use std::collections::HashMap;
 use std::convert::TryInto;
 
 use ballista_core::serde::protobuf::executor_registration::OptionalHost;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
 use ballista_core::serde::protobuf::{
-    CancelJobParams, CancelJobResult, CleanJobDataParams, CleanJobDataResult,
-    ExecuteQueryParams, ExecuteQueryResult, ExecutorHeartbeat, 
ExecutorStoppedParams,
+    execute_query_failure_result, execute_query_result, CancelJobParams, 
CancelJobResult,
+    CleanJobDataParams, CleanJobDataResult, CreateSessionParams, 
CreateSessionResult,
+    ExecuteQueryFailureResult, ExecuteQueryParams, ExecuteQueryResult,
+    ExecuteQuerySuccessResult, ExecutorHeartbeat, ExecutorStoppedParams,
     ExecutorStoppedResult, GetFileMetadataParams, GetFileMetadataResult,
     GetJobStatusParams, GetJobStatusResult, HeartBeatParams, HeartBeatResult,
     PollWorkParams, PollWorkResult, RegisterExecutorParams, 
RegisterExecutorResult,
+    RemoveSessionParams, RemoveSessionResult, UpdateSessionParams, 
UpdateSessionResult,
     UpdateTaskStatusParams, UpdateTaskStatusResult,
 };
 use ballista_core::serde::scheduler::ExecutorMetadata;
@@ -324,6 +328,82 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
         }))
     }
 
+    /// Creates a new session from the settings carried in the request and
+    /// replies with its id.
+    ///
+    /// Fails with `Status::internal` when the settings cannot be built into a
+    /// `BallistaConfig` or when the session manager cannot create the session.
+    async fn create_session(
+        &self,
+        request: Request<CreateSessionParams>,
+    ) -> Result<Response<CreateSessionResult>, Status> {
+        let CreateSessionParams { settings } = request.into_inner();
+
+        // Fold every key/value pair of the request into one config builder.
+        let config = settings
+            .iter()
+            .fold(BallistaConfig::builder(), |builder, kv| {
+                builder.set(&kv.key, &kv.value)
+            })
+            .build()
+            .map_err(|e| {
+                let msg = format!("Could not parse configs: {e}");
+                error!("{}", msg);
+                Status::internal(msg)
+            })?;
+
+        let session_ctx = self
+            .state
+            .session_manager
+            .create_session(&config)
+            .await
+            .map_err(|e| {
+                Status::internal(format!("Failed to create SessionContext: {e:?}"))
+            })?;
+
+        let session_id = session_ctx.session_id();
+        Ok(Response::new(CreateSessionResult { session_id }))
+    }
+
+    /// Applies the settings carried in the request to the existing session named
+    /// by `session_id`.
+    ///
+    /// Replies with `success: true` once the session manager accepted the update.
+    /// Fails with `Status::internal` when the settings cannot be built into a
+    /// `BallistaConfig` or when the session manager rejects the update.
+    async fn update_session(
+        &self,
+        request: Request<UpdateSessionParams>,
+    ) -> Result<Response<UpdateSessionResult>, Status> {
+        let session_params = request.into_inner();
+        // parse config
+        let mut config_builder = BallistaConfig::builder();
+        for kv_pair in &session_params.settings {
+            config_builder = config_builder.set(&kv_pair.key, &kv_pair.value);
+        }
+        let config = config_builder.build().map_err(|e| {
+            let msg = format!("Could not parse configs: {e}");
+            error!("{}", msg);
+            Status::internal(msg)
+        })?;
+
+        // Report this as an update failure and name the session, so the error
+        // cannot be mistaken for one raised by `create_session`.
+        self.state
+            .session_manager
+            .update_session(&session_params.session_id, &config)
+            .await
+            .map_err(|e| {
+                Status::internal(format!(
+                    "Failed to update SessionContext for session {}: {e:?}",
+                    session_params.session_id
+                ))
+            })?;
+
+        Ok(Response::new(UpdateSessionResult { success: true }))
+    }
+
+    /// Removes the session named in the request through the session manager.
+    ///
+    /// Replies with `success: true` once the removal call returns `Ok`; an error
+    /// from the session manager is surfaced as `Status::internal`.
+    async fn remove_session(
+        &self,
+        request: Request<RemoveSessionParams>,
+    ) -> Result<Response<RemoveSessionResult>, Status> {
+        let RemoveSessionParams { session_id } = request.into_inner();
+
+        let removal = self
+            .state
+            .session_manager
+            .remove_session(&session_id)
+            .await;
+
+        if let Err(e) = removal {
+            return Err(Status::internal(format!(
+                "Failed to remove SessionContext: {e:?} for session {}",
+                session_id
+            )));
+        }
+
+        Ok(Response::new(RemoveSessionResult { success: true }))
+    }
+
     async fn execute_query(
         &self,
         request: Request<ExecuteQueryParams>,
@@ -331,36 +411,39 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
         let query_params = request.into_inner();
         if let ExecuteQueryParams {
             query: Some(query),
-            settings,
             optional_session_id,
+            settings,
         } = query_params
         {
-            // parse config
-            let mut config_builder = BallistaConfig::builder();
-            for kv_pair in &settings {
-                config_builder = config_builder.set(&kv_pair.key, 
&kv_pair.value);
+            let mut query_settings = HashMap::new();
+            for kv_pair in settings {
+                query_settings.insert(kv_pair.key, kv_pair.value);
             }
-            let config = config_builder.build().map_err(|e| {
-                let msg = format!("Could not parse configs: {e}");
-                error!("{}", msg);
-                Status::internal(msg)
-            })?;
 
             let (session_id, session_ctx) = match optional_session_id {
                 Some(OptionalSessionId::SessionId(session_id)) => {
-                    let ctx = self
-                        .state
-                        .session_manager
-                        .update_session(&session_id, &config)
-                        .await
-                        .map_err(|e| {
-                            Status::internal(format!(
-                                "Failed to load SessionContext for session ID 
{session_id}: {e:?}"
-                            ))
-                        })?;
-                    (session_id, ctx)
+                    match 
self.state.session_manager.get_session(&session_id).await {
+                        Ok(ctx) => (session_id, ctx),
+                        Err(e) => {
+                            let msg = format!("Failed to load SessionContext 
for session ID {session_id}: {e}");
+                            error!("{}", msg);
+                            return Ok(Response::new(ExecuteQueryResult {
+                                result: 
Some(execute_query_result::Result::Failure(
+                                    ExecuteQueryFailureResult {
+                                        failure: 
Some(execute_query_failure_result::Failure::SessionNotFound(msg)),
+                                    },
+                                )),
+                            }));
+                        }
+                    }
                 }
                 _ => {
+                    // Create default config
+                    let config = BallistaConfig::builder().build().map_err(|e| 
{
+                        let msg = format!("Could not parse configs: {e}");
+                        error!("{}", msg);
+                        Status::internal(msg)
+                    })?;
                     let ctx = self
                         .state
                         .session_manager
@@ -377,37 +460,57 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             };
 
             let plan = match query {
-                Query::LogicalPlan(message) => 
T::try_decode(message.as_slice())
-                    .and_then(|m| {
+                Query::LogicalPlan(message) => {
+                    match T::try_decode(message.as_slice()).and_then(|m| {
                         m.try_into_logical_plan(
                             session_ctx.deref(),
                             self.state.codec.logical_extension_codec(),
                         )
-                    })
-                    .map_err(|e| {
-                        let msg = format!("Could not parse logical plan 
protobuf: {e}");
-                        error!("{}", msg);
-                        Status::internal(msg)
-                    })?,
-                Query::Sql(sql) => session_ctx
-                    .sql(&sql)
-                    .await
-                    .and_then(|df| df.into_optimized_plan())
-                    .map_err(|e| {
-                        let msg = format!("Error parsing SQL: {e}");
-                        error!("{}", msg);
-                        Status::internal(msg)
-                    })?,
+                    }) {
+                        Ok(plan) => plan,
+                        Err(e) => {
+                            let msg =
+                                format!("Could not parse logical plan 
protobuf: {e}");
+                            error!("{}", msg);
+                            return Ok(Response::new(ExecuteQueryResult {
+                                result: 
Some(execute_query_result::Result::Failure(
+                                    ExecuteQueryFailureResult {
+                                        failure: 
Some(execute_query_failure_result::Failure::PlanParsingFailure(msg)),
+                                    },
+                                )),
+                            }));
+                        }
+                    }
+                }
+                Query::Sql(sql) => {
+                    match session_ctx
+                        .sql(&sql)
+                        .await
+                        .and_then(|df| df.into_optimized_plan())
+                    {
+                        Ok(plan) => plan,
+                        Err(e) => {
+                            let msg = format!("Error parsing SQL: {e}");
+                            error!("{}", msg);
+                            return Ok(Response::new(ExecuteQueryResult {
+                                result: 
Some(execute_query_result::Result::Failure(
+                                    ExecuteQueryFailureResult {
+                                        failure: 
Some(execute_query_failure_result::Failure::PlanParsingFailure(msg)),
+                                    },
+                                )),
+                            }));
+                        }
+                    }
+                }
             };
 
             debug!("Received plan for execution: {:?}", plan);
 
             let job_id = self.state.task_manager.generate_job_id();
-            let job_name = config
-                .settings()
+            let job_name = query_settings
                 .get(BALLISTA_JOB_NAME)
                 .cloned()
-                .unwrap_or_default();
+                .unwrap_or_else(|| "None".to_string());
 
             self.submit_job(&job_id, &job_name, session_ctx, &plan)
                 .await
@@ -419,37 +522,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                     Status::internal(msg)
                 })?;
 
-            Ok(Response::new(ExecuteQueryResult { job_id, session_id }))
-        } else if let ExecuteQueryParams {
-            query: None,
-            settings,
-            optional_session_id: None,
-        } = query_params
-        {
-            // parse config for new session
-            let mut config_builder = BallistaConfig::builder();
-            for kv_pair in &settings {
-                config_builder = config_builder.set(&kv_pair.key, 
&kv_pair.value);
-            }
-            let config = config_builder.build().map_err(|e| {
-                let msg = format!("Could not parse configs: {e}");
-                error!("{}", msg);
-                Status::internal(msg)
-            })?;
-            let session = self
-                .state
-                .session_manager
-                .create_session(&config)
-                .await
-                .map_err(|e| {
-                    Status::internal(format!(
-                        "Failed to create new SessionContext: {e:?}"
-                    ))
-                })?;
-
             Ok(Response::new(ExecuteQueryResult {
-                job_id: "NA".to_owned(),
-                session_id: session.session_id(),
+                result: Some(execute_query_result::Result::Success(
+                    ExecuteQuerySuccessResult { job_id, session_id },
+                )),
             }))
         } else {
             Err(Status::internal("Error parsing request"))
diff --git a/ballista/scheduler/src/state/session_manager.rs b/ballista/scheduler/src/state/session_manager.rs
index e89f6dae..e6f99603 100644
--- a/ballista/scheduler/src/state/session_manager.rs
+++ b/ballista/scheduler/src/state/session_manager.rs
@@ -33,6 +33,13 @@ impl SessionManager {
         Self { state }
     }
 
+    pub async fn remove_session(
+        &self,
+        session_id: &str,
+    ) -> Result<Option<Arc<SessionContext>>> {
+        self.state.remove_session(session_id).await
+    }
+
     pub async fn update_session(
         &self,
         session_id: &str,


Reply via email to