This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new 0e6f0463 Add REST API to generate DOT graph for individual query stage (#310)
0e6f0463 is described below

commit 0e6f0463807f18c469993d54b1c4783367ebd788
Author: Andy Grove <[email protected]>
AuthorDate: Wed Oct 5 12:07:56 2022 -0600

    Add REST API to generate DOT graph for individual query stage (#310)
---
 ballista/rust/scheduler/src/api/handlers.rs        |  60 +++++----
 ballista/rust/scheduler/src/api/mod.rs             |   8 ++
 ballista/rust/scheduler/src/flight_sql.rs          |   3 +-
 .../rust/scheduler/src/scheduler_server/event.rs   |   2 +-
 .../rust/scheduler/src/scheduler_server/grpc.rs    |   8 +-
 .../rust/scheduler/src/scheduler_server/mod.rs     |  12 +-
 .../src/scheduler_server/query_stage_scheduler.rs  |   2 +-
 .../rust/scheduler/src/state/execution_graph.rs    |  32 ++---
 .../src/state/execution_graph/execution_stage.rs   |  24 ++++
 .../scheduler/src/state/execution_graph_dot.rs     | 146 +++++++++++----------
 ballista/rust/scheduler/src/state/mod.rs           |  37 +-----
 ballista/rust/scheduler/src/state/task_manager.rs  |   6 +-
 12 files changed, 183 insertions(+), 157 deletions(-)

diff --git a/ballista/rust/scheduler/src/api/handlers.rs 
b/ballista/rust/scheduler/src/api/handlers.rs
index 807d27a5..d5010d6a 100644
--- a/ballista/rust/scheduler/src/api/handlers.rs
+++ b/ballista/rust/scheduler/src/api/handlers.rs
@@ -145,7 +145,7 @@ pub(crate) async fn get_jobs<T: AsLogicalPlan, U: 
AsExecutionPlan>(
                 ((job.completed_stages as f32 / job.num_stages as f32) * 
100_f32) as u8;
             JobResponse {
                 job_id: job.job_id.to_string(),
-                job_name: job.job_name.to_owned().unwrap_or_default(),
+                job_name: job.job_name.to_string(),
                 job_status,
                 num_stages: job.num_stages,
                 completed_stages: job.completed_stages,
@@ -167,35 +167,27 @@ pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: 
AsExecutionPlan>(
     data_server: SchedulerServer<T, U>,
     job_id: String,
 ) -> Result<impl warp::Reply, Rejection> {
-    let maybe_graph = data_server
+    if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| warp::reject())?;
-
-    match maybe_graph {
-        Some(graph) => Ok(warp::reply::json(&QueryStagesResponse {
+        .map_err(|_| warp::reject())?
+    {
+        Ok(warp::reply::json(&QueryStagesResponse {
             stages: graph
                 .stages()
                 .iter()
                 .map(|(id, stage)| {
                     let mut summary = QueryStageSummary {
                         stage_id: id.to_string(),
-                        stage_status: "".to_string(),
+                        stage_status: stage.variant_name().to_string(),
                         input_rows: 0,
                         output_rows: 0,
                         elapsed_compute: "".to_string(),
                     };
                     match stage {
-                        ExecutionStage::UnResolved(_) => {
-                            summary.stage_status = "Unresolved".to_string();
-                        }
-                        ExecutionStage::Resolved(_) => {
-                            summary.stage_status = "Resolved".to_string();
-                        }
                         ExecutionStage::Running(running_stage) => {
-                            summary.stage_status = "Running".to_string();
                             summary.input_rows = running_stage
                                 .stage_metrics
                                 .as_ref()
@@ -213,7 +205,6 @@ pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: 
AsExecutionPlan>(
                                 .unwrap_or_default();
                         }
                         ExecutionStage::Successful(completed_stage) => {
-                            summary.stage_status = "Completed".to_string();
                             summary.input_rows = get_combined_count(
                                 &completed_stage.stage_metrics,
                                 "input_rows",
@@ -225,15 +216,14 @@ pub(crate) async fn get_query_stages<T: AsLogicalPlan, U: 
AsExecutionPlan>(
                             summary.elapsed_compute =
                                 
get_elapsed_compute_nanos(&completed_stage.stage_metrics);
                         }
-                        ExecutionStage::Failed(_) => {
-                            summary.stage_status = "Failed".to_string();
-                        }
+                        _ => {}
                     }
                     summary
                 })
                 .collect(),
-        })),
-        _ => Ok(warp::reply::json(&QueryStagesResponse { stages: vec![] })),
+        }))
+    } else {
+        Ok(warp::reply::json(&QueryStagesResponse { stages: vec![] }))
     }
 }
 
@@ -273,16 +263,36 @@ pub(crate) async fn get_job_dot_graph<T: AsLogicalPlan, 
U: AsExecutionPlan>(
     data_server: SchedulerServer<T, U>,
     job_id: String,
 ) -> Result<String, Rejection> {
-    let graph = data_server
+    if let Some(graph) = data_server
         .state
         .task_manager
         .get_job_execution_graph(&job_id)
         .await
-        .map_err(|_| warp::reject())?;
+        .map_err(|_| warp::reject())?
+    {
+        ExecutionGraphDot::generate(graph).map_err(|_| warp::reject())
+    } else {
+        Ok("Not Found".to_string())
+    }
+}
 
-    match graph {
-        Some(x) => ExecutionGraphDot::generate(x).map_err(|_| warp::reject()),
-        _ => Ok("Not Found".to_string()),
+/// Generate a dot graph for the specified job id and query stage and return as plain text
+pub(crate) async fn get_query_stage_dot_graph<T: AsLogicalPlan, U: 
AsExecutionPlan>(
+    data_server: SchedulerServer<T, U>,
+    job_id: String,
+    stage_id: usize,
+) -> Result<String, Rejection> {
+    if let Some(graph) = data_server
+        .state
+        .task_manager
+        .get_job_execution_graph(&job_id)
+        .await
+        .map_err(|_| warp::reject())?
+    {
+        ExecutionGraphDot::generate_for_query_stage(graph, stage_id)
+            .map_err(|_| warp::reject())
+    } else {
+        Ok("Not Found".to_string())
     }
 }
 
diff --git a/ballista/rust/scheduler/src/api/mod.rs 
b/ballista/rust/scheduler/src/api/mod.rs
index 17eec052..3f00b767 100644
--- a/ballista/rust/scheduler/src/api/mod.rs
+++ b/ballista/rust/scheduler/src/api/mod.rs
@@ -105,6 +105,13 @@ pub fn get_routes<T: AsLogicalPlan + Clone, U: 'static + 
AsExecutionPlan>(
         .and(with_data_server(scheduler_server.clone()))
         .and_then(|job_id, data_server| 
handlers::get_job_dot_graph(data_server, job_id));
 
+    let route_query_stage_dot =
+        warp::path!("api" / "job" / String / "stage" / usize / "dot")
+            .and(with_data_server(scheduler_server.clone()))
+            .and_then(|job_id, stage_id, data_server| {
+                handlers::get_query_stage_dot_graph(data_server, job_id, 
stage_id)
+            });
+
     let route_job_dot_svg = warp::path!("api" / "job" / String / "dot_svg")
         .and(with_data_server(scheduler_server))
         .and_then(|job_id, data_server| 
handlers::get_job_svg_graph(data_server, job_id));
@@ -114,6 +121,7 @@ pub fn get_routes<T: AsLogicalPlan + Clone, U: 'static + 
AsExecutionPlan>(
         .or(route_jobs)
         .or(route_query_stages)
         .or(route_job_dot)
+        .or(route_query_stage_dot)
         .or(route_job_dot_svg);
     routes.boxed()
 }
diff --git a/ballista/rust/scheduler/src/flight_sql.rs 
b/ballista/rust/scheduler/src/flight_sql.rs
index 3d5222c2..7caca105 100644
--- a/ballista/rust/scheduler/src/flight_sql.rs
+++ b/ballista/rust/scheduler/src/flight_sql.rs
@@ -295,8 +295,9 @@ impl FlightSqlServiceImpl {
         plan: &LogicalPlan,
     ) -> Result<String, Status> {
         let job_id = self.server.state.task_manager.generate_job_id();
+        let job_name = format!("Flight SQL job {}", job_id);
         self.server
-            .submit_job(&job_id, None, ctx, plan)
+            .submit_job(&job_id, &job_name, ctx, plan)
             .await
             .map_err(|e| {
                 let msg =
diff --git a/ballista/rust/scheduler/src/scheduler_server/event.rs 
b/ballista/rust/scheduler/src/scheduler_server/event.rs
index c0e7e16b..d01c0516 100644
--- a/ballista/rust/scheduler/src/scheduler_server/event.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/event.rs
@@ -28,7 +28,7 @@ use std::sync::Arc;
 pub enum QueryStageSchedulerEvent {
     JobQueued {
         job_id: String,
-        job_name: Option<String>,
+        job_name: String,
         session_ctx: Arc<SessionContext>,
         plan: Box<LogicalPlan>,
     },
diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs 
b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
index 05d76b2f..1fbb86a5 100644
--- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
@@ -424,9 +424,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerGrpc
             debug!("Received plan for execution: {:?}", plan);
 
             let job_id = self.state.task_manager.generate_job_id();
-            let job_name = config.settings().get(BALLISTA_JOB_NAME);
+            let job_name = config
+                .settings()
+                .get(BALLISTA_JOB_NAME)
+                .cloned()
+                .unwrap_or_default();
 
-            self.submit_job(&job_id, job_name.cloned(), session_ctx, &plan)
+            self.submit_job(&job_id, &job_name, session_ctx, &plan)
                 .await
                 .map_err(|e| {
                     let msg =
diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs 
b/ballista/rust/scheduler/src/scheduler_server/mod.rs
index 3250a4b2..6cfbf070 100644
--- a/ballista/rust/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs
@@ -141,7 +141,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
     pub(crate) async fn submit_job(
         &self,
         job_id: &str,
-        job_name: Option<String>,
+        job_name: &str,
         ctx: Arc<SessionContext>,
         plan: &LogicalPlan,
     ) -> Result<()> {
@@ -149,7 +149,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
             .get_sender()?
             .post_event(QueryStageSchedulerEvent::JobQueued {
                 job_id: job_id.to_owned(),
-                job_name,
+                job_name: job_name.to_owned(),
                 session_ctx: ctx,
                 plan: Box::new(plan.clone()),
             })
@@ -339,7 +339,7 @@ mod test {
         // Submit job
         scheduler
             .state
-            .submit_job(job_id, None, ctx, &plan)
+            .submit_job(job_id, "", ctx, &plan)
             .await
             .expect("submitting plan");
 
@@ -444,7 +444,7 @@ mod test {
 
         let job_id = "job";
 
-        scheduler.state.submit_job(job_id, None, ctx, &plan).await?;
+        scheduler.state.submit_job(job_id, "", ctx, &plan).await?;
 
         // Complete tasks that are offered through scheduler events
         loop {
@@ -593,7 +593,7 @@ mod test {
 
         let job_id = "job";
 
-        scheduler.state.submit_job(job_id, None, ctx, &plan).await?;
+        scheduler.state.submit_job(job_id, "", ctx, &plan).await?;
 
         let available_tasks = scheduler
             .state
@@ -729,7 +729,7 @@ mod test {
         let job_id = "job";
 
         // This should fail when we try and create the physical plan
-        scheduler.submit_job(job_id, None, ctx, &plan).await?;
+        scheduler.submit_job(job_id, "", ctx, &plan).await?;
 
         let scheduler = scheduler.clone();
 
diff --git 
a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs 
b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
index 64e31d03..96de88a8 100644
--- a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
@@ -80,7 +80,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
                 let state = self.state.clone();
                 tokio::spawn(async move {
                     let event = if let Err(e) = state
-                        .submit_job(&job_id, job_name, session_ctx, &plan)
+                        .submit_job(&job_id, &job_name, session_ctx, &plan)
                         .await
                     {
                         let msg = format!("Error planning job {}: {:?}", 
job_id, e);
diff --git a/ballista/rust/scheduler/src/state/execution_graph.rs 
b/ballista/rust/scheduler/src/state/execution_graph.rs
index 8cb206d3..1524b64c 100644
--- a/ballista/rust/scheduler/src/state/execution_graph.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph.rs
@@ -104,8 +104,8 @@ pub struct ExecutionGraph {
     scheduler_id: String,
     /// ID for this job
     job_id: String,
-    /// Optional job name
-    job_name: Option<String>,
+    /// Job name, can be empty string
+    job_name: String,
     /// Session ID for this job
     session_id: String,
     /// Status of this job
@@ -136,7 +136,7 @@ impl ExecutionGraph {
     pub fn new(
         scheduler_id: &str,
         job_id: &str,
-        job_name: Option<String>,
+        job_name: &str,
         session_id: &str,
         plan: Arc<dyn ExecutionPlan>,
     ) -> Result<Self> {
@@ -152,7 +152,7 @@ impl ExecutionGraph {
         Ok(Self {
             scheduler_id: scheduler_id.to_string(),
             job_id: job_id.to_string(),
-            job_name,
+            job_name: job_name.to_string(),
             session_id: session_id.to_string(),
             status: JobStatus {
                 status: Some(job_status::Status::Queued(QueuedJob {})),
@@ -169,8 +169,8 @@ impl ExecutionGraph {
         self.job_id.as_str()
     }
 
-    pub fn job_name(&self) -> Option<&String> {
-        self.job_name.as_ref()
+    pub fn job_name(&self) -> &str {
+        self.job_name.as_str()
     }
 
     pub fn session_id(&self) -> &str {
@@ -1284,11 +1284,7 @@ impl ExecutionGraph {
         Ok(ExecutionGraph {
             scheduler_id: proto.scheduler_id,
             job_id: proto.job_id,
-            job_name: if proto.job_name.is_empty() {
-                None
-            } else {
-                Some(proto.job_name)
-            },
+            job_name: proto.job_name,
             session_id: proto.session_id,
             status: proto.status.ok_or_else(|| {
                 BallistaError::Internal(
@@ -1364,7 +1360,7 @@ impl ExecutionGraph {
 
         Ok(protobuf::ExecutionGraph {
             job_id: graph.job_id,
-            job_name: graph.job_name.unwrap_or_default(),
+            job_name: graph.job_name,
             session_id: graph.session_id,
             status: Some(graph.status),
             stages,
@@ -2748,7 +2744,7 @@ mod test {
 
         println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
 
-        ExecutionGraph::new("localhost:50050", "job", None, "session", 
plan).unwrap()
+        ExecutionGraph::new("localhost:50050", "job", "", "session", 
plan).unwrap()
     }
 
     async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph {
@@ -2776,7 +2772,7 @@ mod test {
 
         println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
 
-        ExecutionGraph::new("localhost:50050", "job", None, "session", 
plan).unwrap()
+        ExecutionGraph::new("localhost:50050", "job", "", "session", 
plan).unwrap()
     }
 
     async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
@@ -2799,7 +2795,7 @@ mod test {
 
         let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
 
-        ExecutionGraph::new("localhost:50050", "job", None, "session", 
plan).unwrap()
+        ExecutionGraph::new("localhost:50050", "job", "", "session", 
plan).unwrap()
     }
 
     async fn test_join_plan(partition: usize) -> ExecutionGraph {
@@ -2841,7 +2837,7 @@ mod test {
         println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
 
         let graph =
-            ExecutionGraph::new("localhost:50050", "job", None, "session", 
plan).unwrap();
+            ExecutionGraph::new("localhost:50050", "job", "", "session", 
plan).unwrap();
 
         println!("{:?}", graph);
 
@@ -2866,7 +2862,7 @@ mod test {
         println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
 
         let graph =
-            ExecutionGraph::new("localhost:50050", "job", None, "session", 
plan).unwrap();
+            ExecutionGraph::new("localhost:50050", "job", "", "session", 
plan).unwrap();
 
         println!("{:?}", graph);
 
@@ -2891,7 +2887,7 @@ mod test {
         println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
 
         let graph =
-            ExecutionGraph::new("localhost:50050", "job", None, "session", 
plan).unwrap();
+            ExecutionGraph::new("localhost:50050", "job", "", "session", 
plan).unwrap();
 
         println!("{:?}", graph);
 
diff --git 
a/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs 
b/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
index b7ef0087..582b34f1 100644
--- a/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
@@ -73,6 +73,30 @@ impl Debug for ExecutionStage {
     }
 }
 
+impl ExecutionStage {
+    /// Get the name of the variant
+    pub(crate) fn variant_name(&self) -> &str {
+        match self {
+            ExecutionStage::UnResolved(_) => "Unresolved",
+            ExecutionStage::Resolved(_) => "Resolved",
+            ExecutionStage::Running(_) => "Running",
+            ExecutionStage::Successful(_) => "Successful",
+            ExecutionStage::Failed(_) => "Failed",
+        }
+    }
+
+    /// Get the query plan for this query stage
+    pub(crate) fn plan(&self) -> &dyn ExecutionPlan {
+        match self {
+            ExecutionStage::UnResolved(stage) => stage.plan.as_ref(),
+            ExecutionStage::Resolved(stage) => stage.plan.as_ref(),
+            ExecutionStage::Running(stage) => stage.plan.as_ref(),
+            ExecutionStage::Successful(stage) => stage.plan.as_ref(),
+            ExecutionStage::Failed(stage) => stage.plan.as_ref(),
+        }
+    }
+}
+
 /// For a stage whose input stages are not all completed, we say it's a unresolved stage
 #[derive(Clone)]
 pub(crate) struct UnresolvedStage {
diff --git a/ballista/rust/scheduler/src/state/execution_graph_dot.rs 
b/ballista/rust/scheduler/src/state/execution_graph_dot.rs
index 963b9d97..d46e93f8 100644
--- a/ballista/rust/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph_dot.rs
@@ -17,7 +17,7 @@
 
 //! Utilities for producing dot diagrams from execution graphs
 
-use crate::state::execution_graph::{ExecutionGraph, ExecutionStage};
+use crate::state::execution_graph::ExecutionGraph;
 use ballista_core::execution_plans::{
     ShuffleReaderExec, ShuffleWriterExec, UnresolvedShuffleExec,
 };
@@ -56,6 +56,23 @@ impl ExecutionGraphDot {
         dot._generate()
     }
 
+    /// Create a DOT graph for one query stage from the provided ExecutionGraph
+    pub fn generate_for_query_stage(
+        graph: Arc<ExecutionGraph>,
+        stage_id: usize,
+    ) -> Result<String, fmt::Error> {
+        if let Some(stage) = graph.stages().get(&stage_id) {
+            let mut dot = String::new();
+            writeln!(&mut dot, "digraph G {{")?;
+            let stage_name = format!("stage_{}", stage_id);
+            write_stage_plan(&mut dot, &stage_name, stage.plan(), 0)?;
+            writeln!(&mut dot, "}}")?;
+            Ok(dot)
+        } else {
+            Err(fmt::Error::default())
+        }
+    }
+
     fn _generate(&mut self) -> Result<String, fmt::Error> {
         // sort the stages by key for deterministic output for tests
         let stages = self.graph.stages();
@@ -74,53 +91,13 @@ impl ExecutionGraphDot {
             let stage = stages.get(id).unwrap(); // safe unwrap
             let stage_name = format!("stage_{}", id);
             writeln!(&mut dot, "\tsubgraph cluster{} {{", cluster)?;
-            match stage {
-                ExecutionStage::UnResolved(stage) => {
-                    writeln!(&mut dot, "\t\tlabel = \"Stage {} 
[UnResolved]\";", id)?;
-                    stage_meta.push(write_stage_plan(
-                        &mut dot,
-                        &stage_name,
-                        &stage.plan,
-                        0,
-                    )?);
-                }
-                ExecutionStage::Resolved(stage) => {
-                    writeln!(&mut dot, "\t\tlabel = \"Stage {} [Resolved]\";", 
id)?;
-                    stage_meta.push(write_stage_plan(
-                        &mut dot,
-                        &stage_name,
-                        &stage.plan,
-                        0,
-                    )?);
-                }
-                ExecutionStage::Running(stage) => {
-                    writeln!(&mut dot, "\t\tlabel = \"Stage {} [Running]\";", 
id)?;
-                    stage_meta.push(write_stage_plan(
-                        &mut dot,
-                        &stage_name,
-                        &stage.plan,
-                        0,
-                    )?);
-                }
-                ExecutionStage::Successful(stage) => {
-                    writeln!(&mut dot, "\t\tlabel = \"Stage {} 
[Completed]\";", id)?;
-                    stage_meta.push(write_stage_plan(
-                        &mut dot,
-                        &stage_name,
-                        &stage.plan,
-                        0,
-                    )?);
-                }
-                ExecutionStage::Failed(stage) => {
-                    writeln!(&mut dot, "\t\tlabel = \"Stage {} [FAILED]\";", 
id)?;
-                    stage_meta.push(write_stage_plan(
-                        &mut dot,
-                        &stage_name,
-                        &stage.plan,
-                        0,
-                    )?);
-                }
-            }
+            writeln!(
+                &mut dot,
+                "\t\tlabel = \"Stage {} [{}]\";",
+                id,
+                stage.variant_name()
+            )?;
+            stage_meta.push(write_stage_plan(&mut dot, &stage_name, 
stage.plan(), 0)?);
             cluster += 1;
             writeln!(&mut dot, "\t}}")?; // end of subgraph
         }
@@ -151,7 +128,7 @@ impl ExecutionGraphDot {
 fn write_stage_plan(
     f: &mut String,
     prefix: &str,
-    plan: &Arc<dyn ExecutionPlan>,
+    plan: &dyn ExecutionPlan,
     i: usize,
 ) -> Result<StagePlanState, fmt::Error> {
     let mut state = StagePlanState {
@@ -164,7 +141,7 @@ fn write_stage_plan(
 fn write_plan_recursive(
     f: &mut String,
     prefix: &str,
-    plan: &Arc<dyn ExecutionPlan>,
+    plan: &dyn ExecutionPlan,
     i: usize,
     state: &mut StagePlanState,
 ) -> Result<(), fmt::Error> {
@@ -210,7 +187,7 @@ fn write_plan_recursive(
     }
 
     for (j, child) in plan.children().into_iter().enumerate() {
-        write_plan_recursive(f, &node_name, &child, j, state)?;
+        write_plan_recursive(f, &node_name, child.as_ref(), j, state)?;
         // write link from child to parent
         writeln!(f, "\t\t{}_{} -> {}", node_name, j, node_name)?;
     }
@@ -256,7 +233,7 @@ fn sanitize(str: &str, max_len: Option<usize>) -> String {
     sanitized
 }
 
-fn get_operator_name(plan: &Arc<dyn ExecutionPlan>) -> String {
+fn get_operator_name(plan: &dyn ExecutionPlan) -> String {
     if let Some(exec) = plan.as_any().downcast_ref::<FilterExec>() {
         format!("Filter: {}", exec.predicate())
     } else if let Some(exec) = plan.as_any().downcast_ref::<ProjectionExec>() {
@@ -452,21 +429,7 @@ mod tests {
 
     #[tokio::test]
     async fn dot() -> Result<()> {
-        let ctx =
-            
SessionContext::with_config(SessionConfig::new().with_target_partitions(48));
-        let schema =
-            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, 
false)]));
-        let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
-        ctx.register_table("foo", table.clone())?;
-        ctx.register_table("bar", table.clone())?;
-        ctx.register_table("baz", table)?;
-        let df = ctx
-            .sql("SELECT * FROM foo JOIN bar ON foo.a = bar.a JOIN baz on 
bar.a = baz.a")
-            .await?;
-        let plan = df.to_logical_plan()?;
-        let plan = ctx.create_physical_plan(&plan).await?;
-        let graph =
-            ExecutionGraph::new("scheduler_id", "job_id", None, "session_id", 
plan)?;
+        let graph = test_graph().await?;
         let dot = ExecutionGraphDot::generate(Arc::new(graph))
             .map_err(|e| BallistaError::Internal(format!("{:?}", e)))?;
 
@@ -484,7 +447,7 @@ mod tests {
                stage_2_0_0 -> stage_2_0
        }
        subgraph cluster2 {
-               label = "Stage 3 [UnResolved]";
+               label = "Stage 3 [Unresolved]";
                stage_3_0 [shape=box, label="ShuffleWriter [48 partitions]"]
                stage_3_0_0 [shape=box, label="CoalesceBatches 
[batchSize=4096]"]
                stage_3_0_0_0 [shape=box, label="HashJoin
@@ -508,7 +471,7 @@ filter_expr="]
                stage_4_0_0 -> stage_4_0
        }
        subgraph cluster4 {
-               label = "Stage 5 [UnResolved]";
+               label = "Stage 5 [Unresolved]";
                stage_5_0 [shape=box, label="ShuffleWriter [48 partitions]"]
                stage_5_0_0 [shape=box, label="Projection: a@0, a@1, a@2"]
                stage_5_0_0_0 [shape=box, label="CoalesceBatches 
[batchSize=4096]"]
@@ -536,4 +499,49 @@ filter_expr="]
         assert_eq!(expected, &dot);
         Ok(())
     }
+
+    #[tokio::test]
+    async fn query_stage() -> Result<()> {
+        let graph = test_graph().await?;
+        let dot = ExecutionGraphDot::generate_for_query_stage(Arc::new(graph), 
3)
+            .map_err(|e| BallistaError::Internal(format!("{:?}", e)))?;
+
+        let expected = r#"digraph G {
+               stage_3_0 [shape=box, label="ShuffleWriter [48 partitions]"]
+               stage_3_0_0 [shape=box, label="CoalesceBatches 
[batchSize=4096]"]
+               stage_3_0_0_0 [shape=box, label="HashJoin
+join_expr=a@0 = a@0
+filter_expr="]
+               stage_3_0_0_0_0 [shape=box, label="CoalesceBatches 
[batchSize=4096]"]
+               stage_3_0_0_0_0_0 [shape=box, label="UnresolvedShuffleExec 
[stage_id=1]"]
+               stage_3_0_0_0_0_0 -> stage_3_0_0_0_0
+               stage_3_0_0_0_0 -> stage_3_0_0_0
+               stage_3_0_0_0_1 [shape=box, label="CoalesceBatches 
[batchSize=4096]"]
+               stage_3_0_0_0_1_0 [shape=box, label="UnresolvedShuffleExec 
[stage_id=2]"]
+               stage_3_0_0_0_1_0 -> stage_3_0_0_0_1
+               stage_3_0_0_0_1 -> stage_3_0_0_0
+               stage_3_0_0_0 -> stage_3_0_0
+               stage_3_0_0 -> stage_3_0
+}
+"#;
+        assert_eq!(expected, &dot);
+        Ok(())
+    }
+
+    async fn test_graph() -> Result<ExecutionGraph> {
+        let ctx =
+            
SessionContext::with_config(SessionConfig::new().with_target_partitions(48));
+        let schema =
+            Arc::new(Schema::new(vec![Field::new("a", DataType::UInt32, 
false)]));
+        let table = Arc::new(MemTable::try_new(schema.clone(), vec![])?);
+        ctx.register_table("foo", table.clone())?;
+        ctx.register_table("bar", table.clone())?;
+        ctx.register_table("baz", table)?;
+        let df = ctx
+            .sql("SELECT * FROM foo JOIN bar ON foo.a = bar.a JOIN baz on 
bar.a = baz.a")
+            .await?;
+        let plan = df.to_logical_plan()?;
+        let plan = ctx.create_physical_plan(&plan).await?;
+        ExecutionGraph::new("scheduler_id", "job_id", "job_name", 
"session_id", plan)
+    }
 }
diff --git a/ballista/rust/scheduler/src/state/mod.rs 
b/ballista/rust/scheduler/src/state/mod.rs
index 94d69f6d..2168ab64 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -251,7 +251,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerState<T,
     pub(crate) async fn submit_job(
         &self,
         job_id: &str,
-        job_name: Option<String>,
+        job_name: &str,
         session_ctx: Arc<SessionContext>,
         plan: &LogicalPlan,
     ) -> Result<()> {
@@ -358,39 +358,19 @@ mod test {
         // Create 4 jobs so we have four pending tasks
         state
             .task_manager
-            .submit_job(
-                "job-1",
-                None,
-                session_ctx.session_id().as_str(),
-                plan.clone(),
-            )
+            .submit_job("job-1", "", session_ctx.session_id().as_str(), 
plan.clone())
             .await?;
         state
             .task_manager
-            .submit_job(
-                "job-2",
-                None,
-                session_ctx.session_id().as_str(),
-                plan.clone(),
-            )
+            .submit_job("job-2", "", session_ctx.session_id().as_str(), 
plan.clone())
             .await?;
         state
             .task_manager
-            .submit_job(
-                "job-3",
-                None,
-                session_ctx.session_id().as_str(),
-                plan.clone(),
-            )
+            .submit_job("job-3", "", session_ctx.session_id().as_str(), 
plan.clone())
             .await?;
         state
             .task_manager
-            .submit_job(
-                "job-4",
-                None,
-                session_ctx.session_id().as_str(),
-                plan.clone(),
-            )
+            .submit_job("job-4", "", session_ctx.session_id().as_str(), 
plan.clone())
             .await?;
 
         let executors = test_executors(1, 4);
@@ -435,12 +415,7 @@ mod test {
         // Create a job
         state
             .task_manager
-            .submit_job(
-                "job-1",
-                None,
-                session_ctx.session_id().as_str(),
-                plan.clone(),
-            )
+            .submit_job("job-1", "", session_ctx.session_id().as_str(), 
plan.clone())
             .await?;
 
         let executors = test_executors(1, 4);
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs 
b/ballista/rust/scheduler/src/state/task_manager.rs
index b7d6fd67..810d0f37 100644
--- a/ballista/rust/scheduler/src/state/task_manager.rs
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -95,7 +95,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
     pub async fn submit_job(
         &self,
         job_id: &str,
-        job_name: Option<String>,
+        job_name: &str,
         session_id: &str,
         plan: Arc<dyn ExecutionPlan>,
     ) -> Result<()> {
@@ -142,7 +142,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
             }
             jobs.push(JobOverview {
                 job_id: job_id.clone(),
-                job_name: graph.job_name().cloned(),
+                job_name: graph.job_name().to_string(),
                 status: graph.status(),
                 num_stages: graph.stage_count(),
                 completed_stages,
@@ -587,7 +587,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> TaskManager<T, U>
 
 pub struct JobOverview {
     pub job_id: String,
-    pub job_name: Option<String>,
+    pub job_name: String,
     pub status: JobStatus,
     pub num_stages: usize,
     pub completed_stages: usize,

Reply via email to