This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 85031c42 Add ExecutionEngine abstraction (#687)
85031c42 is described below

commit 85031c42208238c5f373aae915f141c8e91db6b4
Author: Andy Grove <[email protected]>
AuthorDate: Thu Mar 2 16:23:48 2023 -0700

    Add ExecutionEngine abstraction (#687)
---
 ballista/executor/src/bin/main.rs         |   1 +
 ballista/executor/src/execution_engine.rs | 114 ++++++++++++++++++++++++++++++
 ballista/executor/src/execution_loop.rs   |  15 ++--
 ballista/executor/src/executor.rs         |  74 ++++++++-----------
 ballista/executor/src/executor_process.rs |   5 ++
 ballista/executor/src/executor_server.rs  |  19 ++---
 ballista/executor/src/lib.rs              |   1 +
 ballista/executor/src/metrics/mod.rs      |  14 ++--
 ballista/executor/src/standalone.rs       |   1 +
 9 files changed, 174 insertions(+), 70 deletions(-)

diff --git a/ballista/executor/src/bin/main.rs b/ballista/executor/src/bin/main.rs
index 706a8b97..b5765165 100644
--- a/ballista/executor/src/bin/main.rs
+++ b/ballista/executor/src/bin/main.rs
@@ -77,6 +77,7 @@ async fn main() -> Result<()> {
         print_thread_info: opt.print_thread_info,
         job_data_ttl_seconds: opt.job_data_ttl_seconds,
         job_data_clean_up_interval_seconds: opt.job_data_clean_up_interval_seconds,
+        execution_engine: None,
     };
 
     start_executor_process(config).await
diff --git a/ballista/executor/src/execution_engine.rs b/ballista/executor/src/execution_engine.rs
new file mode 100644
index 00000000..d62176a9
--- /dev/null
+++ b/ballista/executor/src/execution_engine.rs
@@ -0,0 +1,114 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use async_trait::async_trait;
+use ballista_core::execution_plans::ShuffleWriterExec;
+use ballista_core::serde::protobuf::ShuffleWritePartition;
+use ballista_core::utils;
+use datafusion::error::{DataFusionError, Result};
+use datafusion::execution::context::TaskContext;
+use datafusion::physical_plan::metrics::MetricsSet;
+use datafusion::physical_plan::ExecutionPlan;
+use std::fmt::Debug;
+use std::sync::Arc;
+
+/// Execution engine extension point
+
+pub trait ExecutionEngine: Sync + Send {
+    fn create_query_stage_exec(
+        &self,
+        job_id: String,
+        stage_id: usize,
+        plan: Arc<dyn ExecutionPlan>,
+        work_dir: &str,
+    ) -> Result<Arc<dyn QueryStageExecutor>>;
+}
+
+/// QueryStageExecutor executes a section of a query plan that has consistent partitioning and
+/// can be executed as one unit with each partition being executed in parallel. The output of each
+/// partition is re-partitioned and streamed to disk in Arrow IPC format. Future stages of the query
+/// will use the ShuffleReaderExec to read these results.
+#[async_trait]
+pub trait QueryStageExecutor: Sync + Send + Debug {
+    async fn execute_query_stage(
+        &self,
+        input_partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<Vec<ShuffleWritePartition>>;
+
+    fn collect_plan_metrics(&self) -> Vec<MetricsSet>;
+}
+
+pub struct DefaultExecutionEngine {}
+
+impl ExecutionEngine for DefaultExecutionEngine {
+    fn create_query_stage_exec(
+        &self,
+        job_id: String,
+        stage_id: usize,
+        plan: Arc<dyn ExecutionPlan>,
+        work_dir: &str,
+    ) -> Result<Arc<dyn QueryStageExecutor>> {
+        // the query plan created by the scheduler always starts with a ShuffleWriterExec
+        let exec = if let Some(shuffle_writer) =
+            plan.as_any().downcast_ref::<ShuffleWriterExec>()
+        {
+            // recreate the shuffle writer with the correct working directory
+            ShuffleWriterExec::try_new(
+                job_id,
+                stage_id,
+                plan.children()[0].clone(),
+                work_dir.to_string(),
+                shuffle_writer.shuffle_output_partitioning().cloned(),
+            )
+        } else {
+            Err(DataFusionError::Internal(
+                "Plan passed to new_query_stage_exec is not a 
ShuffleWriterExec"
+                    .to_string(),
+            ))
+        }?;
+        Ok(Arc::new(DefaultQueryStageExec::new(exec)))
+    }
+}
+
+#[derive(Debug)]
+pub struct DefaultQueryStageExec {
+    shuffle_writer: ShuffleWriterExec,
+}
+
+impl DefaultQueryStageExec {
+    pub fn new(shuffle_writer: ShuffleWriterExec) -> Self {
+        Self { shuffle_writer }
+    }
+}
+
+#[async_trait]
+impl QueryStageExecutor for DefaultQueryStageExec {
+    async fn execute_query_stage(
+        &self,
+        input_partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<Vec<ShuffleWritePartition>> {
+        self.shuffle_writer
+            .execute_shuffle_write(input_partition, context)
+            .await
+    }
+
+    fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
+        utils::collect_plan_metrics(self.shuffle_writer.children()[0].as_ref())
+    }
+}
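
The two traits above are the whole extension point: an alternate engine
implements ExecutionEngine to produce stage executors, and QueryStageExecutor
to run them. A minimal sketch of what a third-party engine might look like
(not part of this commit; the MyEngine/MyQueryStageExec names are hypothetical
and the actual execution logic is elided):

    use async_trait::async_trait;
    use ballista_core::serde::protobuf::ShuffleWritePartition;
    use ballista_executor::execution_engine::{ExecutionEngine, QueryStageExecutor};
    use datafusion::error::Result;
    use datafusion::execution::context::TaskContext;
    use datafusion::physical_plan::metrics::MetricsSet;
    use datafusion::physical_plan::ExecutionPlan;
    use std::sync::Arc;

    #[derive(Debug)]
    struct MyEngine;

    impl ExecutionEngine for MyEngine {
        fn create_query_stage_exec(
            &self,
            job_id: String,
            stage_id: usize,
            plan: Arc<dyn ExecutionPlan>,
            work_dir: &str,
        ) -> Result<Arc<dyn QueryStageExecutor>> {
            // Wrap the stage plan in an engine-specific executor
            Ok(Arc::new(MyQueryStageExec {
                job_id,
                stage_id,
                plan,
                work_dir: work_dir.to_string(),
            }))
        }
    }

    #[derive(Debug)]
    struct MyQueryStageExec {
        job_id: String,
        stage_id: usize,
        plan: Arc<dyn ExecutionPlan>,
        work_dir: String,
    }

    #[async_trait]
    impl QueryStageExecutor for MyQueryStageExec {
        async fn execute_query_stage(
            &self,
            input_partition: usize,
            context: Arc<TaskContext>,
        ) -> Result<Vec<ShuffleWritePartition>> {
            // This is where a specialized engine would execute the stage,
            // write shuffle output under self.work_dir, and describe the
            // partitions it wrote (elided in this sketch)
            todo!()
        }

        fn collect_plan_metrics(&self) -> Vec<MetricsSet> {
            // No DataFusion metrics to report in this sketch
            vec![]
        }
    }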
diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs
index 8efce745..8cf6a4da 100644
--- a/ballista/executor/src/execution_loop.rs
+++ b/ballista/executor/src/execution_loop.rs
@@ -29,7 +29,6 @@ use crate::{as_task_status, TaskExecutionTimes};
 use ballista_core::error::BallistaError;
 use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId};
 use ballista_core::serde::BallistaCodec;
-use ballista_core::utils::collect_plan_metrics;
 use datafusion::execution::context::TaskContext;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::physical_plan::from_proto::parse_protobuf_hash_partitioning;
@@ -209,8 +208,12 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
         plan.schema().as_ref(),
     )?;
 
-    let shuffle_writer_plan =
-        executor.new_shuffle_writer(job_id.clone(), stage_id as usize, plan)?;
+    let query_stage_exec = executor.execution_engine.create_query_stage_exec(
+        job_id.clone(),
+        stage_id as usize,
+        plan,
+        &executor.work_dir,
+    )?;
     dedicated_executor.spawn(async move {
         use std::panic::AssertUnwindSafe;
         let part = PartitionId {
@@ -219,10 +222,10 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
             partition_id: partition_id as usize,
         };
 
-        let execution_result = match AssertUnwindSafe(executor.execute_shuffle_write(
+        let execution_result = match AssertUnwindSafe(executor.execute_query_stage(
             task_id as usize,
             part.clone(),
-            shuffle_writer_plan.clone(),
+            query_stage_exec.clone(),
             task_context,
             shuffle_output_partitioning,
         ))
@@ -240,7 +243,7 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
         info!("Done with task {}", task_identity);
         debug!("Statistics: {:?}", execution_result);
 
-        let plan_metrics = collect_plan_metrics(shuffle_writer_plan.as_ref());
+        let plan_metrics = query_stage_exec.collect_plan_metrics();
         let operator_metrics = plan_metrics
             .into_iter()
             .map(|m| m.try_into())
diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs
index 867b3ba8..4e60f6ab 100644
--- a/ballista/executor/src/executor.rs
+++ b/ballista/executor/src/executor.rs
@@ -17,28 +17,26 @@
 
 //! Ballista executor logic
 
-use dashmap::DashMap;
-use std::collections::HashMap;
-use std::future::Future;
-use std::pin::Pin;
-use std::sync::Arc;
-use std::task::{Context, Poll};
-
+use crate::execution_engine::DefaultExecutionEngine;
+use crate::execution_engine::ExecutionEngine;
+use crate::execution_engine::QueryStageExecutor;
 use crate::metrics::ExecutorMetricsCollector;
 use ballista_core::error::BallistaError;
-use ballista_core::execution_plans::ShuffleWriterExec;
 use ballista_core::serde::protobuf;
 use ballista_core::serde::protobuf::ExecutorRegistration;
-use datafusion::error::DataFusionError;
+use ballista_core::serde::scheduler::PartitionId;
+use dashmap::DashMap;
 use datafusion::execution::context::TaskContext;
 use datafusion::execution::runtime_env::RuntimeEnv;
-
 use datafusion::physical_plan::udaf::AggregateUDF;
 use datafusion::physical_plan::udf::ScalarUDF;
-use datafusion::physical_plan::{ExecutionPlan, Partitioning};
+use datafusion::physical_plan::Partitioning;
 use futures::future::AbortHandle;
-
-use ballista_core::serde::scheduler::PartitionId;
+use std::collections::HashMap;
+use std::future::Future;
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
 
 pub struct TasksDrainedFuture(pub Arc<Executor>);
 
@@ -82,6 +80,10 @@ pub struct Executor {
 
     /// Handles to abort executing tasks
     abort_handles: AbortHandles,
+
+    /// Execution engine that the executor will delegate to
+    /// for executing query stages
+    pub(crate) execution_engine: Arc<dyn ExecutionEngine>,
 }
 
 impl Executor {
@@ -92,6 +94,7 @@ impl Executor {
         runtime: Arc<RuntimeEnv>,
         metrics_collector: Arc<dyn ExecutorMetricsCollector>,
         concurrent_tasks: usize,
+        execution_engine: Option<Arc<dyn ExecutionEngine>>,
     ) -> Self {
         Self {
             metadata,
@@ -103,6 +106,8 @@ impl Executor {
             metrics_collector,
             concurrent_tasks,
             abort_handles: Default::default(),
+            execution_engine: execution_engine
+                .unwrap_or_else(|| Arc::new(DefaultExecutionEngine {})),
         }
     }
 }
@@ -111,16 +116,16 @@ impl Executor {
     /// Execute one partition of a query stage and persist the result to disk in IPC format. On
     /// success, return a RecordBatch containing metadata about the results, including path
     /// and statistics.
-    pub async fn execute_shuffle_write(
+    pub async fn execute_query_stage(
         &self,
         task_id: usize,
         partition: PartitionId,
-        shuffle_writer: Arc<ShuffleWriterExec>,
+        query_stage_exec: Arc<dyn QueryStageExecutor>,
         task_ctx: Arc<TaskContext>,
         _shuffle_output_partitioning: Option<Partitioning>,
     ) -> Result<Vec<protobuf::ShuffleWritePartition>, BallistaError> {
         let (task, abort_handle) = futures::future::abortable(
-            shuffle_writer.execute_shuffle_write(partition.partition_id, task_ctx),
+            query_stage_exec.execute_query_stage(partition.partition_id, task_ctx),
         );
 
         self.abort_handles
@@ -134,39 +139,12 @@ impl Executor {
             &partition.job_id,
             partition.stage_id,
             partition.partition_id,
-            shuffle_writer,
+            query_stage_exec,
         );
 
         Ok(partitions)
     }
 
-    /// Recreate the shuffle writer with the correct working directory.
-    pub fn new_shuffle_writer(
-        &self,
-        job_id: String,
-        stage_id: usize,
-        plan: Arc<dyn ExecutionPlan>,
-    ) -> Result<Arc<ShuffleWriterExec>, BallistaError> {
-        let exec = if let Some(shuffle_writer) =
-            plan.as_any().downcast_ref::<ShuffleWriterExec>()
-        {
-            // recreate the shuffle writer with the correct working directory
-            ShuffleWriterExec::try_new(
-                job_id,
-                stage_id,
-                plan.children()[0].clone(),
-                self.work_dir.clone(),
-                shuffle_writer.shuffle_output_partitioning().cloned(),
-            )
-        } else {
-            Err(DataFusionError::Internal(
-                "Plan passed to execute_shuffle_write is not a 
ShuffleWriterExec"
-                    .to_string(),
-            ))
-        }?;
-        Ok(Arc::new(exec))
-    }
-
     pub async fn cancel_task(
         &self,
         task_id: usize,
@@ -208,6 +186,7 @@ mod test {
     use ballista_core::serde::protobuf::ExecutorRegistration;
     use datafusion::execution::context::TaskContext;
 
+    use crate::execution_engine::DefaultQueryStageExec;
     use ballista_core::serde::scheduler::PartitionId;
     use datafusion::error::DataFusionError;
     use datafusion::physical_expr::PhysicalSortExpr;
@@ -307,6 +286,8 @@ mod test {
         )
         .expect("creating shuffle writer");
 
+        let query_stage_exec = DefaultQueryStageExec::new(shuffle_write);
+
         let executor_registration = ExecutorRegistration {
             id: "executor".to_string(),
             port: 0,
@@ -323,6 +304,7 @@ mod test {
             ctx.runtime_env(),
             Arc::new(LoggingMetricsCollector {}),
             2,
+            None,
         );
 
         let (sender, receiver) = tokio::sync::oneshot::channel();
@@ -336,10 +318,10 @@ mod test {
                 partition_id: 0,
             };
             let task_result = executor_clone
-                .execute_shuffle_write(
+                .execute_query_stage(
                     1,
                     part,
-                    Arc::new(shuffle_write),
+                    Arc::new(query_stage_exec),
                     ctx.task_ctx(),
                     None,
                 )
diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs
index 6db3de06..6aea9b6e 100644
--- a/ballista/executor/src/executor_process.rs
+++ b/ballista/executor/src/executor_process.rs
@@ -55,6 +55,7 @@ use ballista_core::utils::{
 };
 use ballista_core::BALLISTA_VERSION;
 
+use crate::execution_engine::ExecutionEngine;
 use crate::executor::{Executor, TasksDrainedFuture};
 use crate::executor_server::TERMINATING;
 use crate::flight_service::BallistaFlightService;
@@ -82,6 +83,9 @@ pub struct ExecutorProcessConfig {
     pub log_rotation_policy: LogRotationPolicy,
     pub job_data_ttl_seconds: u64,
     pub job_data_clean_up_interval_seconds: u64,
+    /// Optional execution engine to use to execute physical plans, will default to
+    /// DataFusion if none is provided.
+    pub execution_engine: Option<Arc<dyn ExecutionEngine>>,
 }
 
 pub async fn start_executor_process(opt: ExecutorProcessConfig) -> Result<()> {
@@ -181,6 +185,7 @@ pub async fn start_executor_process(opt: ExecutorProcessConfig) -> Result<()> {
         runtime,
         metrics_collector,
         concurrent_tasks,
+        opt.execution_engine,
     ));
 
     let connect_timeout = opt.scheduler_connect_timeout_seconds as u64;
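
With the new execution_engine field on ExecutorProcessConfig, an embedder can
inject an engine at startup without any CLI changes. A rough sketch, reusing
the hypothetical MyEngine from above (build_base_config is a made-up helper
standing in for the option parsing that bin/main.rs does):

    use ballista_executor::executor_process::start_executor_process;
    use std::sync::Arc;

    // build_base_config(): hypothetical helper returning an
    // ExecutorProcessConfig populated the same way bin/main.rs populates it
    let mut config = build_base_config();
    // Override the default DataFusion-based engine
    config.execution_engine = Some(Arc::new(MyEngine));
    start_executor_process(config).await?;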
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 89f2eef5..4468be98 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -42,9 +42,7 @@ use ballista_core::serde::protobuf::{
 use ballista_core::serde::scheduler::PartitionId;
 use ballista_core::serde::scheduler::TaskDefinition;
 use ballista_core::serde::BallistaCodec;
-use ballista_core::utils::{
-    collect_plan_metrics, create_grpc_client_connection, create_grpc_server,
-};
+use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
 use dashmap::DashMap;
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::ExecutionPlan;
@@ -352,9 +350,12 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         let stage_id = task.stage_id;
         let stage_attempt_num = task.stage_attempt_num;
         let partition_id = task.partition_id;
-        let shuffle_writer_plan =
-            self.executor
-                .new_shuffle_writer(job_id.clone(), stage_id, plan)?;
+        let query_stage_exec = self.executor.execution_engine.create_query_stage_exec(
+            job_id.clone(),
+            stage_id,
+            plan,
+            &self.executor.work_dir,
+        )?;
 
         let part = PartitionId {
             job_id: job_id.clone(),
@@ -366,10 +367,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
 
         let execution_result = self
             .executor
-            .execute_shuffle_write(
+            .execute_query_stage(
                 task_id,
                 part.clone(),
-                shuffle_writer_plan.clone(),
+                query_stage_exec.clone(),
                 task_context,
                 shuffle_output_partitioning,
             )
@@ -377,7 +378,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         info!("Done with task {}", task_identity);
         debug!("Statistics: {:?}", execution_result);
 
-        let plan_metrics = collect_plan_metrics(shuffle_writer_plan.as_ref());
+        let plan_metrics = query_stage_exec.collect_plan_metrics();
         let operator_metrics = plan_metrics
             .into_iter()
             .map(|m| m.try_into())
diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs
index 6f1ce33d..beb9faac 100644
--- a/ballista/executor/src/lib.rs
+++ b/ballista/executor/src/lib.rs
@@ -18,6 +18,7 @@
 #![doc = include_str!("../README.md")]
 
 pub mod collect;
+pub mod execution_engine;
 pub mod execution_loop;
 pub mod executor;
 pub mod executor_process;
diff --git a/ballista/executor/src/metrics/mod.rs b/ballista/executor/src/metrics/mod.rs
index e4daaac1..9a0f58fa 100644
--- a/ballista/executor/src/metrics/mod.rs
+++ b/ballista/executor/src/metrics/mod.rs
@@ -15,8 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use ballista_core::execution_plans::ShuffleWriterExec;
-use datafusion::physical_plan::display::DisplayableExecutionPlan;
+use crate::execution_engine::QueryStageExecutor;
 use log::info;
 use std::sync::Arc;
 
@@ -32,7 +31,7 @@ pub trait ExecutorMetricsCollector: Send + Sync {
         job_id: &str,
         stage_id: usize,
         partition: usize,
-        plan: Arc<ShuffleWriterExec>,
+        plan: Arc<dyn QueryStageExecutor>,
     );
 }
 
@@ -47,14 +46,11 @@ impl ExecutorMetricsCollector for LoggingMetricsCollector {
         job_id: &str,
         stage_id: usize,
         partition: usize,
-        plan: Arc<ShuffleWriterExec>,
+        plan: Arc<dyn QueryStageExecutor>,
     ) {
         info!(
-            "=== [{}/{}/{}] Physical plan with metrics ===\n{}\n",
-            job_id,
-            stage_id,
-            partition,
-            DisplayableExecutionPlan::with_metrics(plan.as_ref()).indent()
+            "=== [{}/{}/{}] Physical plan with metrics ===\n{:?}\n",
+            job_id, stage_id, partition, plan
         );
     }
 }
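
Custom collectors now receive Arc<dyn QueryStageExecutor> rather than the
concrete ShuffleWriterExec, so they can no longer downcast to the plan and
must go through the trait. A sketch of an alternative collector (the
record_stage method name is assumed, since this hunk does not show it):

    use ballista_executor::execution_engine::QueryStageExecutor;
    use ballista_executor::metrics::ExecutorMetricsCollector;
    use std::sync::Arc;

    #[derive(Default)]
    struct CountingMetricsCollector;

    impl ExecutorMetricsCollector for CountingMetricsCollector {
        fn record_stage(
            &self,
            job_id: &str,
            stage_id: usize,
            partition: usize,
            plan: Arc<dyn QueryStageExecutor>,
        ) {
            // Use the metrics exposed by the trait instead of inspecting
            // the physical plan directly
            let metric_sets = plan.collect_plan_metrics();
            println!(
                "job {job_id} stage {stage_id} partition {partition}: collected {} metric sets",
                metric_sets.len()
            );
        }
    }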
diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs
index d38a725f..692ad46d 100644
--- a/ballista/executor/src/standalone.rs
+++ b/ballista/executor/src/standalone.rs
@@ -83,6 +83,7 @@ pub async fn new_standalone_executor<
         Arc::new(RuntimeEnv::new(config).unwrap()),
         Arc::new(LoggingMetricsCollector::default()),
         concurrent_tasks,
+        None,
     ));
 
     let service = BallistaFlightService::new();
