This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new ad6ca18d Decouple `ExecutionGraph` and `DistributedPlanner` (#1221)
ad6ca18d is described below

commit ad6ca18da8d5ca581bcd04c5d4f24ce27efdadf5
Author: Marko Milenković <[email protected]>
AuthorDate: Thu Apr 10 09:21:48 2025 +0100

    Decouple `ExecutionGraph` and `DistributedPlanner` (#1221)
    
    Make it possible to provide a different `DistributedPlanner` to
    `ExecutionGraph`.
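
For illustration (not part of the commit itself), here is a minimal sketch of how a caller could now plug in its own planner. `LoggingPlanner` is a hypothetical example; the imports assume code living inside the scheduler crate (an external crate would import from `ballista_scheduler::planner` instead):

    use std::sync::Arc;

    use ballista_core::error::Result;
    use ballista_core::execution_plans::ShuffleWriterExec;
    use datafusion::physical_plan::ExecutionPlan;

    use crate::planner::{DefaultDistributedPlanner, DistributedPlanner};

    /// Hypothetical planner that delegates to the default implementation
    /// while logging each planning request.
    pub struct LoggingPlanner {
        inner: DefaultDistributedPlanner,
    }

    impl DistributedPlanner for LoggingPlanner {
        fn plan_query_stages<'a>(
            &'a mut self,
            job_id: &'a str,
            execution_plan: Arc<dyn ExecutionPlan>,
        ) -> Result<Vec<Arc<ShuffleWriterExec>>> {
            log::info!("planning query stages for job {job_id}");
            self.inner.plan_query_stages(job_id, execution_plan)
        }
    }

    // The planner is then passed to `ExecutionGraph::new` through its new
    // trailing `&mut dyn DistributedPlanner` argument, for example:
    //     let mut planner = LoggingPlanner { inner: DefaultDistributedPlanner::new() };
    //     ExecutionGraph::new(/* existing arguments */, &mut planner)?;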
---
 ballista/scheduler/src/planner.rs                  | 34 +++++++++++++++-------
 ballista/scheduler/src/state/execution_graph.rs    |  5 ++--
 .../scheduler/src/state/execution_graph_dot.rs     |  5 ++++
 ballista/scheduler/src/state/task_manager.rs       |  3 ++
 ballista/scheduler/src/test_utils.rs               | 19 ++++++++----
 5 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs
index 7d8a19bd..61e903bc 100644
--- a/ballista/scheduler/src/planner.rs
+++ b/ballista/scheduler/src/planner.rs
@@ -36,27 +36,39 @@ use log::{debug, info};
 
 type PartialQueryStageResult = (Arc<dyn ExecutionPlan>, Vec<Arc<ShuffleWriterExec>>);
 
-pub struct DistributedPlanner {
+/// [DistributedPlanner] breaks a physical plan into a vector of query stages
+pub trait DistributedPlanner {
+    /// Returns a vector of ExecutionPlans, where the root node is a [ShuffleWriterExec].
+    /// Plans that depend on the input of other plans will have leaf nodes of type [UnresolvedShuffleExec].
+    /// A [ShuffleWriterExec] is created whenever the partitioning changes.
+    fn plan_query_stages<'a>(
+        &'a mut self,
+        job_id: &'a str,
+        execution_plan: Arc<dyn ExecutionPlan>,
+    ) -> Result<Vec<Arc<ShuffleWriterExec>>>;
+}
+/// Default implementation of [DistributedPlanner]
+pub struct DefaultDistributedPlanner {
     next_stage_id: usize,
 }
 
-impl DistributedPlanner {
+impl DefaultDistributedPlanner {
     pub fn new() -> Self {
         Self { next_stage_id: 0 }
     }
 }
 
-impl Default for DistributedPlanner {
+impl Default for DefaultDistributedPlanner {
     fn default() -> Self {
         Self::new()
     }
 }
 
-impl DistributedPlanner {
+impl DistributedPlanner for DefaultDistributedPlanner {
     /// Returns a vector of ExecutionPlans, where the root node is a [ShuffleWriterExec].
     /// Plans that depend on the input of other plans will have leaf nodes of type [UnresolvedShuffleExec].
     /// A [ShuffleWriterExec] is created whenever the partitioning changes.
-    pub fn plan_query_stages<'a>(
+    fn plan_query_stages<'a>(
         &'a mut self,
         job_id: &'a str,
         execution_plan: Arc<dyn ExecutionPlan>,
@@ -72,7 +84,9 @@ impl DistributedPlanner {
         )?);
         Ok(stages)
     }
+}
 
+impl DefaultDistributedPlanner {
     /// Returns a potentially modified version of the input execution_plan along with the resulting query stages.
     /// This function is needed because the input execution_plan might need to be modified, but it might not hold a
     /// complete query stage (its parent might also belong to the same stage)
@@ -292,7 +306,7 @@ fn create_shuffle_writer(
 
 #[cfg(test)]
 mod test {
-    use crate::planner::DistributedPlanner;
+    use crate::planner::{DefaultDistributedPlanner, DistributedPlanner};
     use crate::test_utils::datafusion_test_context;
     use ballista_core::error::BallistaError;
     use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec};
@@ -346,7 +360,7 @@ mod test {
         let plan = session_state.optimize(&plan)?;
         let plan = session_state.create_physical_plan(&plan).await?;
 
-        let mut planner = DistributedPlanner::new();
+        let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
         let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
         for (i, stage) in stages.iter().enumerate() {
@@ -460,7 +474,7 @@ order by
         let plan = session_state.optimize(&plan)?;
         let plan = session_state.create_physical_plan(&plan).await?;
 
-        let mut planner = DistributedPlanner::new();
+        let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
         let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
         for (i, stage) in stages.iter().enumerate() {
@@ -628,7 +642,7 @@ order by
         let plan = session_state.optimize(&plan)?;
         let plan = session_state.create_physical_plan(&plan).await?;
 
-        let mut planner = DistributedPlanner::new();
+        let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
         let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
         for (i, stage) in stages.iter().enumerate() {
@@ -737,7 +751,7 @@ order by
         let plan = session_state.optimize(&plan)?;
         let plan = session_state.create_physical_plan(&plan).await?;
 
-        let mut planner = DistributedPlanner::new();
+        let mut planner = DefaultDistributedPlanner::new();
         let job_uuid = Uuid::new_v4();
         let stages = planner.plan_query_stages(&job_uuid.to_string(), plan)?;
 
diff --git a/ballista/scheduler/src/state/execution_graph.rs b/ballista/scheduler/src/state/execution_graph.rs
index f3e6bf76..a24e0f3c 100644
--- a/ballista/scheduler/src/state/execution_graph.rs
+++ b/ballista/scheduler/src/state/execution_graph.rs
@@ -140,6 +140,7 @@ pub struct RunningTaskInfo {
 }
 
 impl ExecutionGraph {
+    #[allow(clippy::too_many_arguments)]
     pub fn new(
         scheduler_id: &str,
         job_id: &str,
@@ -148,11 +149,9 @@ impl ExecutionGraph {
         plan: Arc<dyn ExecutionPlan>,
         queued_at: u64,
         session_config: Arc<SessionConfig>,
+        planner: &mut dyn DistributedPlanner,
     ) -> Result<Self> {
-        let mut planner = DistributedPlanner::new();
-
         let output_partitions = plan.properties().output_partitioning().partition_count();
-
         let shuffle_stages = planner.plan_query_stages(job_id, plan)?;
 
         let builder = ExecutionStageBuilder::new(session_config.clone());
diff --git a/ballista/scheduler/src/state/execution_graph_dot.rs b/ballista/scheduler/src/state/execution_graph_dot.rs
index 68a2ebdf..f6944af1 100644
--- a/ballista/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/scheduler/src/state/execution_graph_dot.rs
@@ -415,6 +415,7 @@ fn get_file_scan(scan: &FileScanConfig) -> String {
 
 #[cfg(test)]
 mod tests {
+    use crate::planner::DefaultDistributedPlanner;
     use crate::state::execution_graph::ExecutionGraph;
     use crate::state::execution_graph_dot::ExecutionGraphDot;
     use ballista_core::error::{BallistaError, Result};
@@ -645,6 +646,7 @@ filter_expr="]
             .await?;
         let plan = df.into_optimized_plan()?;
         let plan = ctx.state().create_physical_plan(&plan).await?;
+        let mut planner = DefaultDistributedPlanner::new();
         ExecutionGraph::new(
             "scheduler_id",
             "job_id",
@@ -653,6 +655,7 @@ filter_expr="]
             plan,
             0,
             Arc::new(SessionConfig::new_with_ballista()),
+            &mut planner,
         )
     }
 
@@ -679,6 +682,7 @@ filter_expr="]
             .await?;
         let plan = df.into_optimized_plan()?;
         let plan = ctx.state().create_physical_plan(&plan).await?;
+        let mut planner = DefaultDistributedPlanner::new();
         ExecutionGraph::new(
             "scheduler_id",
             "job_id",
@@ -687,6 +691,7 @@ filter_expr="]
             plan,
             0,
             Arc::new(SessionConfig::new_with_ballista()),
+            &mut planner,
         )
     }
 }
diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs
index 53a352bd..b3e6e064 100644
--- a/ballista/scheduler/src/state/task_manager.rs
+++ b/ballista/scheduler/src/state/task_manager.rs
@@ -15,6 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::planner::DefaultDistributedPlanner;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 
 use crate::state::execution_graph::{
@@ -207,6 +208,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         queued_at: u64,
         session_config: Arc<SessionConfig>,
     ) -> Result<()> {
+        let mut planner = DefaultDistributedPlanner::new();
         let mut graph = ExecutionGraph::new(
             &self.scheduler_id,
             job_id,
@@ -215,6 +217,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             plan,
             queued_at,
             session_config,
+            &mut planner,
         )?;
         info!("Submitting execution graph: {:?}", graph);
 
diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs
index 8e4565a4..9f8c2ecd 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -28,6 +28,7 @@ use async_trait::async_trait;
 
 use crate::config::SchedulerConfig;
 use crate::metrics::SchedulerMetricsCollector;
+use crate::planner::DefaultDistributedPlanner;
 use crate::scheduler_server::{timestamp_millis, SchedulerServer};
 
 use crate::state::executor_manager::ExecutorManager;
@@ -861,7 +862,7 @@ pub async fn test_aggregation_plan_with_job_id(
         "{}",
         DisplayableExecutionPlan::new(plan.as_ref()).indent(false)
     );
-
+    let mut planner = DefaultDistributedPlanner::new();
     ExecutionGraph::new(
         "localhost:50050",
         job_id,
@@ -870,6 +871,7 @@ pub async fn test_aggregation_plan_with_job_id(
         plan,
         0,
         Arc::new(SessionConfig::new_with_ballista()),
+        &mut planner,
     )
     .unwrap()
 }
@@ -906,7 +908,7 @@ pub async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph {
         "{}",
         DisplayableExecutionPlan::new(plan.as_ref()).indent(false)
     );
-
+    let mut planner = DefaultDistributedPlanner::new();
     ExecutionGraph::new(
         "localhost:50050",
         "job",
@@ -915,6 +917,7 @@ pub async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph {
         plan,
         0,
         Arc::new(SessionConfig::new_with_ballista()),
+        &mut planner,
     )
     .unwrap()
 }
@@ -943,7 +946,7 @@ pub async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
         .create_physical_plan(&optimized_plan)
         .await
         .unwrap();
-
+    let mut planner = DefaultDistributedPlanner::new();
     ExecutionGraph::new(
         "localhost:50050",
         "job",
@@ -952,6 +955,7 @@ pub async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
         plan,
         0,
         Arc::new(SessionConfig::new_with_ballista()),
+        &mut planner,
     )
     .unwrap()
 }
@@ -1001,7 +1005,7 @@ pub async fn test_join_plan(partition: usize) -> ExecutionGraph {
         "{}",
         DisplayableExecutionPlan::new(plan.as_ref()).indent(false)
     );
-
+    let mut planner = DefaultDistributedPlanner::new();
     let graph = ExecutionGraph::new(
         "localhost:50050",
         "job",
@@ -1010,6 +1014,7 @@ pub async fn test_join_plan(partition: usize) -> ExecutionGraph {
         plan,
         0,
         Arc::new(SessionConfig::new_with_ballista()),
+        &mut planner,
     )
     .unwrap();
 
@@ -1041,7 +1046,7 @@ pub async fn test_union_all_plan(partition: usize) -> ExecutionGraph {
         "{}",
         DisplayableExecutionPlan::new(plan.as_ref()).indent(false)
     );
-
+    let mut planner = DefaultDistributedPlanner::new();
     let graph = ExecutionGraph::new(
         "localhost:50050",
         "job",
@@ -1050,6 +1055,7 @@ pub async fn test_union_all_plan(partition: usize) -> ExecutionGraph {
         plan,
         0,
         Arc::new(SessionConfig::new_with_ballista()),
+        &mut planner,
     )
     .unwrap();
 
@@ -1081,7 +1087,7 @@ pub async fn test_union_plan(partition: usize) -> ExecutionGraph {
         "{}",
         DisplayableExecutionPlan::new(plan.as_ref()).indent(false)
     );
-
+    let mut planner = DefaultDistributedPlanner::new();
     let graph = ExecutionGraph::new(
         "localhost:50050",
         "job",
@@ -1090,6 +1096,7 @@ pub async fn test_union_plan(partition: usize) -> ExecutionGraph {
         plan,
         0,
         Arc::new(SessionConfig::new_with_ballista()),
+        &mut planner,
     )
     .unwrap();
 

