This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new 850f9d61 Update task status to its job curator scheduler (#181)
850f9d61 is described below

commit 850f9d61abc2b898c278bdbd4b95d1183c4e7ec0
Author: yahoNanJing <[email protected]>
AuthorDate: Fri Sep 2 00:37:59 2022 +0800

    Update task status to its job curator scheduler (#181)
    
    Co-authored-by: yangzhong <[email protected]>
---
 ballista/rust/core/proto/ballista.proto           |   3 +-
 ballista/rust/executor/src/executor_server.rs     | 220 +++++++++++++++++-----
 ballista/rust/scheduler/src/state/mod.rs          |   6 +-
 ballista/rust/scheduler/src/state/task_manager.rs |  47 ++---
 4 files changed, 192 insertions(+), 84 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto 
b/ballista/rust/core/proto/ballista.proto
index f3b703c6..ee03d556 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -807,7 +807,8 @@ message CancelJobResult {
 
 message LaunchTaskParams {
   // Allow to launch a task set to an executor at once
-  repeated TaskDefinition task = 1;
+  repeated TaskDefinition tasks = 1;
+  string scheduler_id = 2;
 }
 
 message LaunchTaskResult {
diff --git a/ballista/rust/executor/src/executor_server.rs 
b/ballista/rust/executor/src/executor_server.rs
index dc2724fd..f00ca7f0 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -21,9 +21,9 @@ use std::convert::TryInto;
 use std::ops::Deref;
 use std::sync::Arc;
 use std::time::{Duration, SystemTime, UNIX_EPOCH};
-use tokio::sync::mpsc;
+use tokio::sync::{mpsc, RwLock};
 
-use log::{debug, error, info};
+use log::{debug, error, info, warn};
 use tonic::transport::Channel;
 use tonic::{Request, Response, Status};
 
@@ -41,7 +41,9 @@ use ballista_core::serde::protobuf::{
 };
 use ballista_core::serde::scheduler::ExecutorState;
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
-use ballista_core::utils::{collect_plan_metrics, create_grpc_server};
+use ballista_core::utils::{
+    collect_plan_metrics, create_grpc_client_connection, create_grpc_server,
+};
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::ExecutionPlan;
 use datafusion_proto::logical_plan::AsLogicalPlan;
@@ -54,6 +56,21 @@ use crate::executor::Executor;
 use crate::shutdown::ShutdownNotifier;
 
 type ServerHandle = JoinHandle<Result<(), BallistaError>>;
+type SchedulerClients = Arc<RwLock<HashMap<String, 
SchedulerGrpcClient<Channel>>>>;
+
+/// Wrap TaskDefinition with its curator scheduler id for task update to its 
specific curator scheduler later
+#[derive(Debug)]
+struct CuratorTaskDefinition {
+    scheduler_id: String,
+    task: TaskDefinition,
+}
+
+/// Wrap TaskStatus with its curator scheduler id for task update to its 
specific curator scheduler later
+#[derive(Debug)]
+struct CuratorTaskStatus {
+    scheduler_id: String,
+    task_status: TaskStatus,
+}
 
 pub async fn startup<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
     mut scheduler: SchedulerGrpcClient<Channel>,
@@ -62,9 +79,10 @@ pub async fn startup<T: 'static + AsLogicalPlan, U: 'static 
+ AsExecutionPlan>(
     stop_send: mpsc::Sender<bool>,
     shutdown_noti: &ShutdownNotifier,
 ) -> Result<ServerHandle, BallistaError> {
-    // TODO make the buffer size configurable
-    let (tx_task, rx_task) = mpsc::channel::<TaskDefinition>(1000);
-    let (tx_task_status, rx_task_status) = mpsc::channel::<TaskStatus>(1000);
+    let channel_buf_size = executor.concurrent_tasks * 50;
+    let (tx_task, rx_task) = 
mpsc::channel::<CuratorTaskDefinition>(channel_buf_size);
+    let (tx_task_status, rx_task_status) =
+        mpsc::channel::<CuratorTaskStatus>(channel_buf_size);
 
     let executor_server = ExecutorServer::new(
         scheduler.clone(),
@@ -163,17 +181,18 @@ async fn register_executor(
 pub struct ExecutorServer<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> {
     _start_time: u128,
     executor: Arc<Executor>,
-    scheduler: SchedulerGrpcClient<Channel>,
     executor_env: ExecutorEnv,
     codec: BallistaCodec<T, U>,
+    scheduler_to_register: SchedulerGrpcClient<Channel>,
+    schedulers: SchedulerClients,
 }
 
 #[derive(Clone)]
 struct ExecutorEnv {
     /// Receive `TaskDefinition` from rpc then send to CPU bound tasks pool 
`dedicated_executor`.
-    tx_task: mpsc::Sender<TaskDefinition>,
+    tx_task: mpsc::Sender<CuratorTaskDefinition>,
     /// Receive `TaskStatus` from CPU bound tasks pool `dedicated_executor` 
then use rpc send back to scheduler.
-    tx_task_status: mpsc::Sender<TaskStatus>,
+    tx_task_status: mpsc::Sender<CuratorTaskStatus>,
     /// Receive stop executor request from rpc.
     tx_stop: mpsc::Sender<bool>,
 }
@@ -182,7 +201,7 @@ unsafe impl Sync for ExecutorEnv {}
 
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> 
ExecutorServer<T, U> {
     fn new(
-        scheduler: SchedulerGrpcClient<Channel>,
+        scheduler_to_register: SchedulerGrpcClient<Channel>,
         executor: Arc<Executor>,
         executor_env: ExecutorEnv,
         codec: BallistaCodec<T, U>,
@@ -193,25 +212,86 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
                 .unwrap()
                 .as_millis(),
             executor,
-            scheduler,
             executor_env,
             codec,
+            scheduler_to_register,
+            schedulers: Default::default(),
         }
     }
 
+    async fn get_scheduler_client(
+        &self,
+        scheduler_id: &str,
+    ) -> Result<SchedulerGrpcClient<Channel>, BallistaError> {
+        let scheduler = {
+            let schedulers = self.schedulers.read().await;
+            schedulers.get(scheduler_id).cloned()
+        };
+        // If channel does not exist, create a new one
+        if let Some(scheduler) = scheduler {
+            Ok(scheduler)
+        } else {
+            let scheduler_url = format!("http://{}";, scheduler_id);
+            let connection = 
create_grpc_client_connection(scheduler_url).await?;
+            let scheduler = SchedulerGrpcClient::new(connection);
+
+            {
+                let mut schedulers = self.schedulers.write().await;
+                schedulers.insert(scheduler_id.to_owned(), scheduler.clone());
+            }
+
+            Ok(scheduler)
+        }
+    }
+
+    /// 1. First Heartbeat to its registration scheduler, if successful then 
return; else go next.
+    /// 2. Heartbeat to schedulers which has launching tasks to this executor 
until one succeeds
     async fn heartbeat(&self) {
-        // TODO Error handling
-        self.scheduler
-            .clone()
-            .heart_beat_from_executor(HeartBeatParams {
-                executor_id: self.executor.metadata.id.clone(),
-                state: Some(self.get_executor_state().into()),
-            })
+        let heartbeat_params = HeartBeatParams {
+            executor_id: self.executor.metadata.id.clone(),
+            state: Some(self.get_executor_state().into()),
+        };
+        let mut scheduler = self.scheduler_to_register.clone();
+        match scheduler
+            .heart_beat_from_executor(heartbeat_params.clone())
             .await
-            .unwrap();
+        {
+            Ok(_) => {
+                return;
+            }
+            Err(e) => {
+                warn!(
+                    "Fail to update heartbeat to its registration scheduler 
due to {:?}",
+                    e
+                );
+            }
+        };
+
+        let schedulers = self.schedulers.read().await.clone();
+        for (scheduler_id, mut scheduler) in schedulers {
+            match scheduler
+                .heart_beat_from_executor(heartbeat_params.clone())
+                .await
+            {
+                Ok(_) => {
+                    break;
+                }
+                Err(e) => {
+                    warn!(
+                        "Fail to update heartbeat to scheduler {} due to {:?}",
+                        scheduler_id, e
+                    );
+                }
+            }
+        }
     }
 
-    async fn run_task(&self, task: TaskDefinition) -> Result<(), 
BallistaError> {
+    async fn run_task(
+        &self,
+        curator_task: CuratorTaskDefinition,
+    ) -> Result<(), BallistaError> {
+        let scheduler_id = curator_task.scheduler_id;
+        let task = curator_task.task;
         let task_id = task.task_id.unwrap();
         let task_id_log = format!(
             "{}/{}/{}",
@@ -295,7 +375,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorServer<T,
         );
 
         let task_status_sender = self.executor_env.tx_task_status.clone();
-        task_status_sender.send(task_status).await.unwrap();
+        task_status_sender
+            .send(CuratorTaskStatus {
+                scheduler_id,
+                task_status,
+            })
+            .await
+            .unwrap();
         Ok(())
     }
 
@@ -354,8 +440,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
 
     fn start(
         &self,
-        mut rx_task: mpsc::Receiver<TaskDefinition>,
-        mut rx_task_status: mpsc::Receiver<TaskStatus>,
+        mut rx_task: mpsc::Receiver<CuratorTaskDefinition>,
+        mut rx_task_status: mpsc::Receiver<CuratorTaskStatus>,
         shutdown_noti: &ShutdownNotifier,
     ) {
         //1. loop for task status reporting
@@ -366,9 +452,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
             info!("Starting the task status reporter");
             // As long as the shutdown notification has not been received
             while !tasks_status_shutdown.is_shutdown() {
-                let mut tasks_status = vec![];
+                let mut curator_task_status_map: HashMap<String, 
Vec<TaskStatus>> =
+                    HashMap::new();
                 // First try to fetch task status from the channel in 
*blocking* mode
-                let maybe_task_status = tokio::select! {
+                let maybe_task_status: Option<CuratorTaskStatus> = 
tokio::select! {
                      task_status = rx_task_status.recv() => task_status,
                     _ = tasks_status_shutdown.recv() => {
                         info!("Stop task status reporting loop");
@@ -377,8 +464,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
                     }
                 };
 
+                let mut fetched_task_num = 0usize;
                 if let Some(task_status) = maybe_task_status {
-                    tasks_status.push(task_status);
+                    let task_status_vec = curator_task_status_map
+                        .entry(task_status.scheduler_id)
+                        .or_insert_with(Vec::new);
+                    task_status_vec.push(task_status.task_status);
+                    fetched_task_num += 1;
                 } else {
                     info!("Channel is closed and will exit the task status 
report loop.");
                     drop(tasks_status_complete);
@@ -388,14 +480,15 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
                 // Then try to fetch by non-blocking mode to fetch as much 
finished tasks as possible
                 loop {
                     match rx_task_status.try_recv() {
-                        Ok(task_sta) => {
-                            tasks_status.push(task_sta);
+                        Ok(task_status) => {
+                            let task_status_vec = curator_task_status_map
+                                .entry(task_status.scheduler_id)
+                                .or_insert_with(Vec::new);
+                            task_status_vec.push(task_status.task_status);
+                            fetched_task_num += 1;
                         }
                         Err(TryRecvError::Empty) => {
-                            info!(
-                                "Fetched {} tasks status to report",
-                                tasks_status.len()
-                            );
+                            info!("Fetched {} tasks status to report", 
fetched_task_num);
                             break;
                         }
                         Err(TryRecvError::Disconnected) => {
@@ -406,16 +499,33 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
                     }
                 }
 
-                if let Err(e) = executor_server
-                    .scheduler
-                    .clone()
-                    .update_task_status(UpdateTaskStatusParams {
-                        executor_id: 
executor_server.executor.metadata.id.clone(),
-                        task_status: tasks_status.clone(),
-                    })
-                    .await
-                {
-                    error!("Fail to update tasks {:?} due to {:?}", 
tasks_status, e);
+                for (scheduler_id, tasks_status) in 
curator_task_status_map.into_iter() {
+                    match 
executor_server.get_scheduler_client(&scheduler_id).await {
+                        Ok(mut scheduler) => {
+                            if let Err(e) = scheduler
+                                .update_task_status(UpdateTaskStatusParams {
+                                    executor_id: executor_server
+                                        .executor
+                                        .metadata
+                                        .id
+                                        .clone(),
+                                    task_status: tasks_status.clone(),
+                                })
+                                .await
+                            {
+                                error!(
+                                    "Fail to update tasks {:?} due to {:?}",
+                                    tasks_status, e
+                                );
+                            }
+                        }
+                        Err(e) => {
+                            error!(
+                                "Fail to connect to scheduler {} due to {:?}",
+                                scheduler_id, e
+                            );
+                        }
+                    }
                 }
             }
         });
@@ -436,7 +546,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
 
             // As long as the shutdown notification has not been received
             while !task_runner_shutdown.is_shutdown() {
-                let maybe_task = tokio::select! {
+                let maybe_task: Option<CuratorTaskDefinition> = tokio::select! 
{
                      task = rx_task.recv() => task,
                     _ = task_runner_shutdown.recv() => {
                         info!("Stop the task runner pool");
@@ -444,8 +554,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
                         return;
                     }
                 };
-                if let Some(task) = maybe_task {
-                    if let Some(task_id) = &task.task_id {
+                if let Some(curator_task) = maybe_task {
+                    if let Some(task_id) = &curator_task.task.task_id {
                         let task_id_log = format!(
                             "{}/{}/{}",
                             task_id.job_id, task_id.stage_id, 
task_id.partition_id
@@ -454,7 +564,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
 
                         let server = executor_server.clone();
                         dedicated_executor.spawn(async move {
-                            server.run_task(task).await.unwrap_or_else(|e| {
+                            
server.run_task(curator_task).await.unwrap_or_else(|e| {
                                 error!(
                                     "Fail to run the task {:?} due to {:?}",
                                     task_id_log, e
@@ -462,7 +572,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskRunnerPool<T,
                             });
                         });
                     } else {
-                        error!("There's no task id in the task definition 
{:?}", task);
+                        error!(
+                            "There's no task id in the task definition {:?}",
+                            curator_task
+                        );
                     }
                 } else {
                     info!("Channel is closed and will exit the task receive 
loop");
@@ -482,10 +595,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> ExecutorGrpc
         &self,
         request: Request<LaunchTaskParams>,
     ) -> Result<Response<LaunchTaskResult>, Status> {
-        let tasks = request.into_inner().task;
+        let LaunchTaskParams {
+            tasks,
+            scheduler_id,
+        } = request.into_inner();
         let task_sender = self.executor_env.tx_task.clone();
         for task in tasks {
-            task_sender.send(task).await.unwrap();
+            task_sender
+                .send(CuratorTaskDefinition {
+                    scheduler_id: scheduler_id.clone(),
+                    task,
+                })
+                .await
+                .unwrap();
         }
         Ok(Response::new(LaunchTaskResult { success: true }))
     }
diff --git a/ballista/rust/scheduler/src/state/mod.rs 
b/ballista/rust/scheduler/src/state/mod.rs
index 43cb25e6..f45b8e25 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -202,8 +202,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerState<T,
                         .await
                     {
                         Ok(executor) => {
-                            if let Err(e) =
-                                self.task_manager.launch_task(&executor, 
task).await
+                            if let Err(e) = self
+                                .task_manager
+                                .launch_task(&executor, task, 
&self.executor_manager)
+                                .await
                             {
                                 error!("Failed to launch new task: {:?}", e);
                                 unassigned_reservations.push(
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs 
b/ballista/rust/scheduler/src/state/task_manager.rs
index 25357f38..8e205432 100644
--- a/ballista/rust/scheduler/src/state/task_manager.rs
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -387,41 +387,23 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         &self,
         executor: &ExecutorMetadata,
         task: Task,
+        executor_manager: &ExecutorManager,
     ) -> Result<()> {
         info!("Launching task {:?} on executor {:?}", task, executor.id);
         let task_definition = self.prepare_task_definition(task)?;
-        let mut clients = self.clients.write().await;
-        if let Some(client) = clients.get_mut(&executor.id) {
-            client
-                .launch_task(protobuf::LaunchTaskParams {
-                    task: vec![task_definition],
-                })
-                .await
-                .map_err(|e| {
-                    BallistaError::Internal(format!(
-                        "Failed to connect to executor {}: {:?}",
-                        executor.id, e
-                    ))
-                })?;
-        } else {
-            let executor_id = executor.id.clone();
-            let executor_url = format!("http://{}:{}";, executor.host, 
executor.grpc_port);
-            let connection =
-                
ballista_core::utils::create_grpc_client_connection(executor_url).await?;
-            let mut client = ExecutorGrpcClient::new(connection);
-            clients.insert(executor_id, client.clone());
-            client
-                .launch_task(protobuf::LaunchTaskParams {
-                    task: vec![task_definition],
-                })
-                .await
-                .map_err(|e| {
-                    BallistaError::Internal(format!(
-                        "Failed to connect to executor {}: {:?}",
-                        executor.id, e
-                    ))
-                })?;
-        }
+        let mut client = executor_manager.get_client(&executor.id).await?;
+        client
+            .launch_task(protobuf::LaunchTaskParams {
+                tasks: vec![task_definition],
+                scheduler_id: self.scheduler_id.clone(),
+            })
+            .await
+            .map_err(|e| {
+                BallistaError::Internal(format!(
+                    "Failed to connect to executor {}: {:?}",
+                    executor.id, e
+                ))
+            })?;
         Ok(())
     }
 
@@ -431,6 +413,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
         &self,
         _executor: &ExecutorMetadata,
         _task: Task,
+        _executor_manager: &ExecutorManager,
     ) -> Result<()> {
         Ok(())
     }

Reply via email to