This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new f8885061 Cache encoded stage plan (#393)
f8885061 is described below

commit f88850617b1f3992aa4cd5b9eba61a557ec2e62a
Author: yahoNanJing <[email protected]>
AuthorDate: Thu Oct 20 06:04:01 2022 +0800

    Cache encoded stage plan (#393)
    
    * Cache encoded stage plan
    
    Co-authored-by: yangzhong <[email protected]>
---
 ballista/scheduler/src/state/task_manager.rs | 121 +++++++++++++++++++--------
 1 file changed, 86 insertions(+), 35 deletions(-)

diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs
index c73612f5..77b74495 100644
--- a/ballista/scheduler/src/state/task_manager.rs
+++ b/ballista/scheduler/src/state/task_manager.rs
@@ -24,7 +24,6 @@ use crate::state::execution_graph::{
 use crate::state::executor_manager::{ExecutorManager, ExecutorReservation};
 use crate::state::{decode_protobuf, encode_protobuf, with_lock, with_locks};
 use ballista_core::config::BallistaConfig;
-#[cfg(not(test))]
 use ballista_core::error::BallistaError;
 use ballista_core::error::Result;
 
@@ -47,7 +46,7 @@ use std::sync::Arc;
 use std::time::Duration;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tokio::sync::RwLock;
-type ExecutionGraphCache = Arc<DashMap<String, Arc<RwLock<ExecutionGraph>>>>;
+type ActiveJobCache = Arc<DashMap<String, JobInfoCache>>;
 
 // TODO move to configuration file
 /// Default max failure attempts for task level retry
@@ -61,8 +60,25 @@ pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     session_builder: SessionBuilder,
     codec: BallistaCodec<T, U>,
     scheduler_id: String,
+    // Cache for active jobs curated by this scheduler
+    active_job_cache: ActiveJobCache,
+}
+
+#[derive(Clone)]
+struct JobInfoCache {
     // Cache for active execution graphs curated by this scheduler
-    active_job_cache: ExecutionGraphCache,
+    execution_graph: Arc<RwLock<ExecutionGraph>>,
+    // Cache for encoded execution stage plan to avoid duplicated encoding for multiple tasks
+    encoded_stage_plans: HashMap<usize, Vec<u8>>,
+}
+
+impl JobInfoCache {
+    fn new(graph: ExecutionGraph) -> Self {
+        Self {
+            execution_graph: Arc::new(RwLock::new(graph)),
+            encoded_stage_plans: HashMap::new(),
+        }
+    }
 }
 
 #[derive(Clone)]
@@ -113,7 +129,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
 
         graph.revive();
         self.active_job_cache
-            .insert(job_id.to_owned(), Arc::new(RwLock::new(graph)));
+            .insert(job_id.to_owned(), JobInfoCache::new(graph));
 
         Ok(())
     }
@@ -266,8 +282,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         let mut pending_tasks = 0usize;
         let mut assign_tasks = 0usize;
         for pairs in self.active_job_cache.iter() {
-            let (_job_id, graph) = pairs.pair();
-            let mut graph = graph.write().await;
+            let (_job_id, job_info) = pairs.pair();
+            let mut graph = job_info.execution_graph.write().await;
             for reservation in free_reservations.iter().skip(assign_tasks) {
                 if let Some(task) = graph.pop_next_task(&reservation.executor_id)? {
                     assignments.push((reservation.executor_id.clone(), task));
@@ -476,8 +492,8 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         let updated_graphs: DashMap<String, ExecutionGraph> = DashMap::new();
         {
             for pairs in self.active_job_cache.iter() {
-                let (job_id, graph) = pairs.pair();
-                let mut graph = graph.write().await;
+                let (job_id, job_info) = pairs.pair();
+                let mut graph = job_info.execution_graph.write().await;
                 let reset = graph.reset_stages_on_lost_executor(executor_id)?;
                 if !reset.0.is_empty() {
                     updated_graphs.insert(job_id.to_owned(), graph.clone());
@@ -557,31 +573,54 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         task: TaskDescription,
     ) -> Result<TaskDefinition> {
         debug!("Preparing task definition for {:?}", task);
-        let mut plan_buf: Vec<u8> = vec![];
-        let plan_proto =
-            U::try_from_physical_plan(task.plan, self.codec.physical_extension_codec())?;
-        plan_proto.try_encode(&mut plan_buf)?;
-
-        let output_partitioning =
-            hash_partitioning_to_proto(task.output_partitioning.as_ref())?;
-
-        let task_definition = TaskDefinition {
-            task_id: task.task_id as u32,
-            task_attempt_num: task.task_attempt as u32,
-            job_id: task.partition.job_id.clone(),
-            stage_id: task.partition.stage_id as u32,
-            stage_attempt_num: task.stage_attempt_num as u32,
-            partition_id: task.partition.partition_id as u32,
-            plan: plan_buf,
-            output_partitioning,
-            session_id: task.session_id,
-            launch_time: SystemTime::now()
-                .duration_since(UNIX_EPOCH)
-                .unwrap()
-                .as_millis() as u64,
-            props: vec![],
-        };
-        Ok(task_definition)
+
+        let job_id = task.partition.job_id.clone();
+        let stage_id = task.partition.stage_id;
+
+        if let Some(mut job_info) = self.active_job_cache.get_mut(&job_id) {
+            let plan = if let Some(plan) = job_info.encoded_stage_plans.get(&stage_id) {
+                plan.clone()
+            } else {
+                let mut plan_buf: Vec<u8> = vec![];
+                let plan_proto = U::try_from_physical_plan(
+                    task.plan,
+                    self.codec.physical_extension_codec(),
+                )?;
+                plan_proto.try_encode(&mut plan_buf)?;
+
+                job_info
+                    .encoded_stage_plans
+                    .insert(stage_id, plan_buf.clone());
+
+                plan_buf
+            };
+
+            let output_partitioning =
+                hash_partitioning_to_proto(task.output_partitioning.as_ref())?;
+
+            let task_definition = TaskDefinition {
+                task_id: task.task_id as u32,
+                task_attempt_num: task.task_attempt as u32,
+                job_id,
+                stage_id: stage_id as u32,
+                stage_attempt_num: task.stage_attempt_num as u32,
+                partition_id: task.partition.partition_id as u32,
+                plan,
+                output_partitioning,
+                session_id: task.session_id,
+                launch_time: SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .unwrap()
+                    .as_millis() as u64,
+                props: vec![],
+            };
+            Ok(task_definition)
+        } else {
+            Err(BallistaError::General(format!(
+                "Cannot prepare task definition for job {} which is not in active cache",
+                job_id
+            )))
+        }
     }
 
     /// Get the `ExecutionGraph` for the given job ID from cache
@@ -589,7 +628,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         &self,
         job_id: &str,
     ) -> Option<Arc<RwLock<ExecutionGraph>>> {
-        self.active_job_cache.get(job_id).map(|value| value.clone())
+        self.active_job_cache
+            .get(job_id)
+            .map(|value| value.execution_graph.clone())
+    }
+
+    /// Remove the `ExecutionGraph` for the given job ID from cache
+    pub(crate) async fn remove_active_execution_graph(
+        &self,
+        job_id: &str,
+    ) -> Option<Arc<RwLock<ExecutionGraph>>> {
+        self.active_job_cache
+            .remove(job_id)
+            .map(|value| value.1.execution_graph)
     }
 
     /// Remove the `ExecutionGraph` for the given job ID from cache
@@ -658,7 +709,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
 
     async fn clean_up_job_data(
         state: Arc<dyn StateBackendClient>,
-        active_job_cache: ExecutionGraphCache,
+        active_job_cache: ActiveJobCache,
         failed: bool,
         job_id: String,
     ) -> Result<()> {

Reply via email to