This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new 23f8cad7 Executor lost handling (#184)
23f8cad7 is described below

commit 23f8cad74de7d3e08b202eb8c086083586d4511c
Author: mingmwang <[email protected]>
AuthorDate: Fri Sep 9 08:14:15 2022 +0800

    Executor lost handling (#184)
    
    * Executor Lost Handling
    
    * add dead executor sets to ExecutorManager
    
    * add UT for reset_stages in ExecutionGraph
    
    * add ExecutorLost to QueryStageSchedulerEvent
    
    * Fix rollback ResolvedStage
    
    * Tiny fix
    
    * Resolve review comments, add more UT
---
 ballista/rust/core/proto/ballista.proto            |  40 +-
 ballista/rust/core/src/event_loop.rs               |   1 +
 .../core/src/execution_plans/shuffle_reader.rs     |   2 +-
 .../rust/core/src/serde/scheduler/from_proto.rs    |  22 +-
 ballista/rust/core/src/serde/scheduler/mod.rs      |   7 -
 ballista/rust/core/src/serde/scheduler/to_proto.rs |  18 +-
 ballista/rust/executor/src/executor_server.rs      |  25 +-
 ballista/rust/executor/src/main.rs                 |  38 +-
 ballista/rust/scheduler/src/planner.rs             |  28 ++
 .../rust/scheduler/src/scheduler_server/event.rs   |   7 +-
 .../rust/scheduler/src/scheduler_server/grpc.rs    | 242 ++++++++++-
 .../rust/scheduler/src/scheduler_server/mod.rs     | 111 ++++-
 .../src/scheduler_server/query_stage_scheduler.rs  |  17 +-
 .../rust/scheduler/src/state/execution_graph.rs    | 446 ++++++++++++++++++---
 .../src/state/execution_graph/execution_stage.rs   | 255 +++++++++---
 .../rust/scheduler/src/state/executor_manager.rs   | 128 +++++-
 ballista/rust/scheduler/src/state/task_manager.rs  |  30 ++
 17 files changed, 1206 insertions(+), 211 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index ee03d556..7afe04e6 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -440,7 +440,8 @@ message ResolvedStage {
   uint32 partitions = 2;
   PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
-  bytes plan = 5;
+  repeated  GraphStageInput inputs = 5;
+  bytes plan = 6;
 }
 
 message CompletedStage {
@@ -448,9 +449,10 @@ message CompletedStage {
   uint32 partitions = 2;
   PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
-  bytes plan = 5;
-  repeated TaskStatus task_statuses = 6;
-  repeated OperatorMetricsSet stage_metrics = 7;
+  repeated  GraphStageInput inputs = 5;
+  bytes plan = 6;
+  repeated TaskStatus task_statuses = 7;
+  repeated OperatorMetricsSet stage_metrics = 8;
 }
 
 message FailedStage {
@@ -599,11 +601,8 @@ message ExecutorHeartbeat {
   string executor_id = 1;
   // Unix epoch-based timestamp in seconds
   uint64 timestamp = 2;
-  ExecutorState state = 3;
-}
-
-message ExecutorState {
-  repeated ExecutorMetric metrics = 1;
+  repeated ExecutorMetric metrics = 3;
+  ExecutorStatus status = 4;
 }
 
 message ExecutorMetric {
@@ -613,6 +612,14 @@ message ExecutorMetric {
   }
 }
 
+message ExecutorStatus {
+  oneof status {
+    string active = 1;
+    string dead = 2;
+    string unknown = 3;
+  }
+}
+
 message ExecutorSpecification {
   repeated ExecutorResource resources = 1;
 }
@@ -706,7 +713,8 @@ message RegisterExecutorResult {
 
 message HeartBeatParams {
   string executor_id = 1;
-  ExecutorState state = 2;
+  repeated ExecutorMetric metrics = 2;
+  ExecutorStatus status = 3;
 }
 
 message HeartBeatResult {
@@ -724,6 +732,15 @@ message StopExecutorParams {
 message StopExecutorResult {
 }
 
+message ExecutorStoppedParams {
+  string executor_id = 1;
+  // stop reason
+  string reason = 2;
+}
+
+message ExecutorStoppedResult {
+}
+
 message UpdateTaskStatusParams {
   string executor_id = 1;
   // All tasks must be reported until they reach the failed or completed state
@@ -842,6 +859,9 @@ service SchedulerGrpc {
 
   rpc GetJobStatus (GetJobStatusParams) returns (GetJobStatusResult) {}
 
+  // Used by Executor to tell Scheduler it is stopped.
+  rpc ExecutorStopped (ExecutorStoppedParams) returns (ExecutorStoppedResult) 
{}
+
   rpc CancelJob (CancelJobParams) returns (CancelJobResult) {}
 }
 
diff --git a/ballista/rust/core/src/event_loop.rs b/ballista/rust/core/src/event_loop.rs
index 74ee4ebf..a803bf89 100644
--- a/ballista/rust/core/src/event_loop.rs
+++ b/ballista/rust/core/src/event_loop.rs
@@ -123,6 +123,7 @@ impl<E: Send + 'static> EventLoop<E> {
     }
 }
 
+#[derive(Clone)]
 pub struct EventSender<E> {
     tx_event: mpsc::Sender<E>,
 }
diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
index c69d120b..5c0664ef 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -43,7 +43,7 @@ use log::info;
 #[derive(Debug, Clone)]
 pub struct ShuffleReaderExec {
     /// Each partition of a shuffle can read data from multiple locations
-    pub(crate) partition: Vec<Vec<PartitionLocation>>,
+    pub partition: Vec<Vec<PartitionLocation>>,
     pub(crate) schema: SchemaRef,
     /// Execution metrics
     metrics: ExecutionPlanMetricsSet,
diff --git a/ballista/rust/core/src/serde/scheduler/from_proto.rs b/ballista/rust/core/src/serde/scheduler/from_proto.rs
index 42a0c67a..728ad46b 100644
--- a/ballista/rust/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/rust/core/src/serde/scheduler/from_proto.rs
@@ -29,8 +29,8 @@ use crate::serde::protobuf;
 use crate::serde::protobuf::action::ActionType;
 use crate::serde::protobuf::{operator_metric, NamedCount, NamedGauge, 
NamedTime};
 use crate::serde::scheduler::{
-    Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, 
ExecutorState,
-    PartitionId, PartitionLocation, PartitionStats,
+    Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
+    PartitionLocation, PartitionStats,
 };
 
 impl TryInto<Action> for protobuf::Action {
@@ -263,21 +263,3 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
         ret
     }
 }
-
-#[allow(clippy::from_over_into)]
-impl Into<ExecutorState> for protobuf::ExecutorState {
-    fn into(self) -> ExecutorState {
-        let mut ret = ExecutorState {
-            available_memory_size: u64::MAX,
-        };
-        for metric in self.metrics {
-            if let Some(protobuf::executor_metric::Metric::AvailableMemory(
-                available_memory_size,
-            )) = metric.metric
-            {
-                ret.available_memory_size = available_memory_size
-            }
-        }
-        ret
-    }
-}
diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs
index 6f5a61d5..1eee7d72 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -98,13 +98,6 @@ pub struct ExecutorDataChange {
     pub task_slots: i32,
 }
 
-/// The internal state of an executor, like cpu usage, memory usage, etc
-#[derive(Debug, Clone, Copy, Serialize)]
-pub struct ExecutorState {
-    // in bytes
-    pub available_memory_size: u64,
-}
-
 /// Summary of executed partition
 #[derive(Debug, Copy, Clone, Default)]
 pub struct PartitionStats {
diff --git a/ballista/rust/core/src/serde/scheduler/to_proto.rs b/ballista/rust/core/src/serde/scheduler/to_proto.rs
index 7517408b..10c841e9 100644
--- a/ballista/rust/core/src/serde/scheduler/to_proto.rs
+++ b/ballista/rust/core/src/serde/scheduler/to_proto.rs
@@ -24,8 +24,8 @@ use crate::serde::protobuf::action::ActionType;
 
 use crate::serde::protobuf::{operator_metric, NamedCount, NamedGauge, 
NamedTime};
 use crate::serde::scheduler::{
-    Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, 
ExecutorState,
-    PartitionId, PartitionLocation, PartitionStats,
+    Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
+    PartitionLocation, PartitionStats,
 };
 use datafusion::physical_plan::Partitioning;
 
@@ -234,17 +234,3 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
         }
     }
 }
-
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::ExecutorState> for ExecutorState {
-    fn into(self) -> protobuf::ExecutorState {
-        protobuf::ExecutorState {
-            metrics: vec![protobuf::executor_metric::Metric::AvailableMemory(
-                self.available_memory_size,
-            )]
-            .into_iter()
-            .map(|m| protobuf::ExecutorMetric { metric: Some(m) })
-            .collect(),
-        }
-    }
-}
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
index f00ca7f0..6d0719ff 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -35,11 +35,11 @@ use ballista_core::serde::protobuf::executor_grpc_server::{
 use ballista_core::serde::protobuf::executor_registration::OptionalHost;
 use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
 use ballista_core::serde::protobuf::{
-    CancelTasksParams, CancelTasksResult, HeartBeatParams, LaunchTaskParams,
-    LaunchTaskResult, RegisterExecutorParams, StopExecutorParams, 
StopExecutorResult,
-    TaskDefinition, TaskStatus, UpdateTaskStatusParams,
+    executor_metric, executor_status, CancelTasksParams, CancelTasksResult,
+    ExecutorMetric, ExecutorStatus, HeartBeatParams, LaunchTaskParams, 
LaunchTaskResult,
+    RegisterExecutorParams, StopExecutorParams, StopExecutorResult, 
TaskDefinition,
+    TaskStatus, UpdateTaskStatusParams,
 };
-use ballista_core::serde::scheduler::ExecutorState;
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
 use ballista_core::utils::{
     collect_plan_metrics, create_grpc_client_connection, create_grpc_server,
@@ -249,7 +249,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
     async fn heartbeat(&self) {
         let heartbeat_params = HeartBeatParams {
             executor_id: self.executor.metadata.id.clone(),
-            state: Some(self.get_executor_state().into()),
+            metrics: self.get_executor_metrics(),
+            status: Some(ExecutorStatus {
+                status: Some(executor_status::Status::Active("".to_string())),
+            }),
         };
         let mut scheduler = self.scheduler_to_register.clone();
         match scheduler
@@ -385,11 +388,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         Ok(())
     }
 
-    // TODO with real state
-    fn get_executor_state(&self) -> ExecutorState {
-        ExecutorState {
-            available_memory_size: u64::MAX,
-        }
+    // TODO populate with real metrics
+    fn get_executor_metrics(&self) -> Vec<ExecutorMetric> {
+        let available_memory = ExecutorMetric {
+            metric: Some(executor_metric::Metric::AvailableMemory(u64::MAX)),
+        };
+        let executor_metrics = vec![available_memory];
+        executor_metrics
     }
 }
 
diff --git a/ballista/rust/executor/src/main.rs b/ballista/rust/executor/src/main.rs
index 010aaae9..cbf25ebb 100644
--- a/ballista/rust/executor/src/main.rs
+++ b/ballista/rust/executor/src/main.rs
@@ -37,7 +37,7 @@ use ballista_core::config::TaskSchedulingPolicy;
 use ballista_core::error::BallistaError;
 use ballista_core::serde::protobuf::{
     executor_registration, scheduler_grpc_client::SchedulerGrpcClient,
-    ExecutorRegistration, PhysicalPlanNode,
+    ExecutorRegistration, ExecutorStoppedParams, PhysicalPlanNode,
 };
 use ballista_core::serde::scheduler::ExecutorSpecification;
 use ballista_core::serde::BallistaCodec;
@@ -136,8 +136,10 @@ async fn main() -> Result<()> {
     info!("work_dir: {}", work_dir);
     info!("concurrent_tasks: {}", opt.concurrent_tasks);
 
+    // assign this executor an unique ID
+    let executor_id = Uuid::new_v4().to_string();
     let executor_meta = ExecutorRegistration {
-        id: Uuid::new_v4().to_string(), // assign this executor a unique ID
+        id: executor_id.clone(),
         optional_host: external_host
             .clone()
             .map(executor_registration::OptionalHost::Host),
@@ -170,7 +172,7 @@ async fn main() -> Result<()> {
         .await
         .context("Could not connect to scheduler")?;
 
-    let scheduler = SchedulerGrpcClient::new(connection);
+    let mut scheduler = SchedulerGrpcClient::new(connection);
 
     let default_codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
         BallistaCodec::default();
@@ -222,7 +224,7 @@ async fn main() -> Result<()> {
             service_handlers.push(
                 //If there is executor registration error during startup, 
return the error and stop early.
                 executor_server::startup(
-                    scheduler,
+                    scheduler.clone(),
                     executor.clone(),
                     default_codec,
                     stop_send,
@@ -233,7 +235,7 @@ async fn main() -> Result<()> {
         }
         _ => {
             service_handlers.push(tokio::spawn(execution_loop::poll_loop(
-                scheduler,
+                scheduler.clone(),
                 executor.clone(),
                 default_codec,
             )));
@@ -247,15 +249,33 @@ async fn main() -> Result<()> {
     // Concurrently run the service checking and listen for the `shutdown` 
signal and wait for the stop request coming.
     // The check_services runs until an error is encountered, so under normal 
circumstances, this `select!` statement runs
     // until the `shutdown` signal is received or a stop request is coming.
-    tokio::select! {
+    let (notify_scheduler, stop_reason) = tokio::select! {
         service_val = check_services(&mut service_handlers) => {
-             info!("services stopped with reason {:?}", service_val);
+            let msg = format!("executor services stopped with reason {:?}", 
service_val);
+            info!("{:?}", msg);
+            (true, msg)
         },
         _ = signal::ctrl_c() => {
              // sometimes OS can not log ??
-             info!("received ctrl-c event.");
+            let msg = "executor received ctrl-c event.".to_string();
+             info!("{:?}", msg);
+            (true, msg)
         },
-        _ = stop_recv.recv() => {},
+        _ = stop_recv.recv() => {
+            (false, "".to_string())
+        },
+    };
+
+    if notify_scheduler {
+        if let Err(error) = scheduler
+            .executor_stopped(ExecutorStoppedParams {
+                executor_id,
+                reason: stop_reason,
+            })
+            .await
+        {
+            error!("ExecutorStopped grpc failed: {:?}", error);
+        }
     }
 
     // Extract the `shutdown_complete` receiver and transmitter
diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs
index 9c393bac..bda9dd77 100644
--- a/ballista/rust/scheduler/src/planner.rs
+++ b/ballista/rust/scheduler/src/planner.rs
@@ -246,6 +246,34 @@ pub fn remove_unresolved_shuffles(
     Ok(with_new_children_if_necessary(stage, new_children)?)
 }
 
+/// Rollback the ShuffleReaderExec to UnresolvedShuffleExec.
+/// Used when the input stages are finished but some partitions are missing 
due to executor lost.
+/// The entire stage need to be rolled back and rescheduled.
+pub fn rollback_resolved_shuffles(
+    stage: Arc<dyn ExecutionPlan>,
+) -> Result<Arc<dyn ExecutionPlan>> {
+    let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
+    for child in stage.children() {
+        if let Some(shuffle_reader) = 
child.as_any().downcast_ref::<ShuffleReaderExec>() {
+            let partition_locations = &shuffle_reader.partition;
+            let output_partition_count = partition_locations.len();
+            let input_partition_count = partition_locations[0].len();
+            let stage_id = partition_locations[0][0].partition_id.stage_id;
+
+            let unresolved_shuffle = Arc::new(UnresolvedShuffleExec::new(
+                stage_id,
+                shuffle_reader.schema(),
+                input_partition_count,
+                output_partition_count,
+            ));
+            new_children.push(unresolved_shuffle);
+        } else {
+            new_children.push(rollback_resolved_shuffles(child)?);
+        }
+    }
+    Ok(with_new_children_if_necessary(stage, new_children)?)
+}
+
 fn create_shuffle_writer(
     job_id: &str,
     stage_id: usize,
diff --git a/ballista/rust/scheduler/src/scheduler_server/event.rs b/ballista/rust/scheduler/src/scheduler_server/event.rs
index 10793ccc..ad462944 100644
--- a/ballista/rust/scheduler/src/scheduler_server/event.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/event.rs
@@ -23,12 +23,6 @@ use ballista_core::serde::protobuf::TaskStatus;
 use datafusion::prelude::SessionContext;
 use std::sync::Arc;
 
-#[derive(Clone, Debug)]
-pub enum SchedulerServerEvent {
-    /// Offer a list of executor reservations (representing executor task 
slots available for scheduling)
-    Offer(Vec<ExecutorReservation>),
-}
-
 #[derive(Clone)]
 pub enum QueryStageSchedulerEvent {
     JobQueued {
@@ -45,4 +39,5 @@ pub enum QueryStageSchedulerEvent {
     JobUpdated(String),
     TaskUpdating(String, Vec<TaskStatus>),
     ReservationOffering(Vec<ExecutorReservation>),
+    ExecutorLost(String, Option<String>),
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
index 0937158b..09f92cbd 100644
--- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
@@ -21,11 +21,12 @@ use ballista_core::serde::protobuf::execute_query_params::{OptionalSessionId, Qu
 use ballista_core::serde::protobuf::executor_registration::OptionalHost;
 use ballista_core::serde::protobuf::scheduler_grpc_server::SchedulerGrpc;
 use ballista_core::serde::protobuf::{
-    CancelJobParams, CancelJobResult, ExecuteQueryParams, ExecuteQueryResult,
-    ExecutorHeartbeat, GetFileMetadataParams, GetFileMetadataResult, 
GetJobStatusParams,
-    GetJobStatusResult, HeartBeatParams, HeartBeatResult, PollWorkParams, 
PollWorkResult,
-    RegisterExecutorParams, RegisterExecutorResult, UpdateTaskStatusParams,
-    UpdateTaskStatusResult,
+    executor_status, CancelJobParams, CancelJobResult, ExecuteQueryParams,
+    ExecuteQueryResult, ExecutorHeartbeat, ExecutorStatus, 
ExecutorStoppedParams,
+    ExecutorStoppedResult, GetFileMetadataParams, GetFileMetadataResult,
+    GetJobStatusParams, GetJobStatusResult, HeartBeatParams, HeartBeatResult,
+    PollWorkParams, PollWorkResult, RegisterExecutorParams, 
RegisterExecutorResult,
+    UpdateTaskStatusParams, UpdateTaskStatusResult,
 };
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::serde::AsExecutionPlan;
@@ -37,7 +38,7 @@ use datafusion::datasource::file_format::FileFormat;
 use datafusion_proto::logical_plan::AsLogicalPlan;
 use datafusion_proto::protobuf::FileType;
 use futures::TryStreamExt;
-use log::{debug, error, info, trace, warn};
+use log::{debug, error, info, warn};
 
 // use http_body::Body;
 use std::convert::TryInto;
@@ -72,6 +73,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         } = request.into_inner()
         {
             debug!("Received poll_work request for {:?}", metadata);
+            // We might receive buggy poll work requests from dead executors.
+            if self
+                .state
+                .executor_manager
+                .is_dead_executor(&metadata.id.clone())
+            {
+                let error_msg = format!(
+                    "Receive buggy poll work request from dead Executor {}",
+                    metadata.id.clone()
+                );
+                warn!("{}", error_msg);
+                return Err(Status::internal(error_msg));
+            }
             let metadata = ExecutorMetadata {
                 id: metadata.id,
                 host: metadata
@@ -90,7 +104,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
                     .duration_since(UNIX_EPOCH)
                     .expect("Time went backwards")
                     .as_secs(),
-                state: None,
+                metrics: vec![],
+                status: Some(ExecutorStatus {
+                    status: 
Some(executor_status::Status::Active("".to_string())),
+                }),
             };
 
             self.state
@@ -221,17 +238,20 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         &self,
         request: Request<HeartBeatParams>,
     ) -> Result<Response<HeartBeatResult>, Status> {
-        let HeartBeatParams { executor_id, state } = request.into_inner();
-
+        let HeartBeatParams {
+            executor_id,
+            metrics,
+            status,
+        } = request.into_inner();
         debug!("Received heart beat request for {:?}", executor_id);
-        trace!("Related executor state is {:?}", state);
         let executor_heartbeat = ExecutorHeartbeat {
             executor_id,
             timestamp: SystemTime::now()
                 .duration_since(UNIX_EPOCH)
                 .expect("Time went backwards")
                 .as_secs(),
-            state,
+            metrics,
+            status,
         };
         self.state
             .executor_manager
@@ -474,6 +494,36 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         }
     }
 
+    async fn executor_stopped(
+        &self,
+        request: Request<ExecutorStoppedParams>,
+    ) -> Result<Response<ExecutorStoppedResult>, Status> {
+        let ExecutorStoppedParams {
+            executor_id,
+            reason,
+        } = request.into_inner();
+        info!(
+            "Received executor stopped request from Executor {} with reason 
'{}'",
+            executor_id, reason
+        );
+
+        let executor_manager = self.state.executor_manager.clone();
+        let event_sender = 
self.query_stage_event_loop.get_sender().map_err(|e| {
+            let msg = format!("Get query stage event loop error due to {:?}", 
e);
+            error!("{}", msg);
+            Status::internal(msg)
+        })?;
+        Self::remove_executor(executor_manager, event_sender, &executor_id, 
Some(reason))
+            .await
+            .map_err(|e| {
+                let msg = format!("Error to remove executor in Scheduler due 
to {:?}", e);
+                error!("{}", msg);
+                Status::internal(msg)
+            })?;
+
+        Ok(Response::new(ExecutorStoppedResult {}))
+    }
+
     async fn cancel_job(
         &self,
         request: Request<CancelJobParams>,
@@ -487,10 +537,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             .await
             .map_err(|e| {
                 let msg = format!("Error cancelling job {}: {:?}", job_id, e);
+
                 error!("{}", msg);
                 Status::internal(msg)
             })?;
-
         Ok(Response::new(CancelJobResult { cancelled: true }))
     }
 }
@@ -498,6 +548,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
 #[cfg(all(test, feature = "sled"))]
 mod test {
     use std::sync::Arc;
+    use std::time::Duration;
 
     use datafusion::execution::context::default_session_builder;
     use datafusion_proto::protobuf::LogicalPlanNode;
@@ -505,12 +556,14 @@ mod test {
 
     use ballista_core::error::BallistaError;
     use ballista_core::serde::protobuf::{
-        executor_registration::OptionalHost, ExecutorRegistration, 
PhysicalPlanNode,
-        PollWorkParams,
+        executor_registration::OptionalHost, executor_status, 
ExecutorRegistration,
+        ExecutorStatus, ExecutorStoppedParams, HeartBeatParams, 
PhysicalPlanNode,
+        PollWorkParams, RegisterExecutorParams,
     };
     use ballista_core::serde::scheduler::ExecutorSpecification;
     use ballista_core::serde::BallistaCodec;
 
+    use crate::state::executor_manager::DEFAULT_EXECUTOR_TIMEOUT_SECONDS;
     use crate::state::{backend::standalone::StandaloneClient, SchedulerState};
 
     use super::{SchedulerGrpc, SchedulerServer};
@@ -599,4 +652,165 @@ mod test {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_stop_executor() -> Result<(), BallistaError> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+        let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
+            SchedulerServer::new(
+                "localhost:50050".to_owned(),
+                state_storage.clone(),
+                BallistaCodec::default(),
+            );
+        scheduler.init().await?;
+
+        let exec_meta = ExecutorRegistration {
+            id: "abc".to_owned(),
+            optional_host: 
Some(OptionalHost::Host("http://localhost:8080".to_owned())),
+            port: 0,
+            grpc_port: 0,
+            specification: Some(ExecutorSpecification { task_slots: 2 
}.into()),
+        };
+
+        let request: Request<RegisterExecutorParams> =
+            Request::new(RegisterExecutorParams {
+                metadata: Some(exec_meta.clone()),
+            });
+        let response = scheduler
+            .register_executor(request)
+            .await
+            .expect("Received error response")
+            .into_inner();
+
+        // registration should success
+        assert!(response.success);
+
+        let state = scheduler.state.clone();
+        // executor should be registered
+        let stored_executor = state
+            .executor_manager
+            .get_executor_metadata("abc")
+            .await
+            .expect("getting executor");
+
+        assert_eq!(stored_executor.grpc_port, 0);
+        assert_eq!(stored_executor.port, 0);
+        assert_eq!(stored_executor.specification.task_slots, 2);
+        assert_eq!(stored_executor.host, "http://localhost:8080".to_owned());
+
+        let request: Request<ExecutorStoppedParams> =
+            Request::new(ExecutorStoppedParams {
+                executor_id: "abc".to_owned(),
+                reason: "test_stop".to_owned(),
+            });
+
+        let _response = scheduler
+            .executor_stopped(request)
+            .await
+            .expect("Received error response")
+            .into_inner();
+
+        // executor should be registered
+        let _stopped_executor = state
+            .executor_manager
+            .get_executor_metadata("abc")
+            .await
+            .expect("getting executor");
+
+        // executor should be marked to dead
+        assert!(state.executor_manager.is_dead_executor("abc"));
+
+        let active_executors = state
+            .executor_manager
+            .get_alive_executors_within_one_minute();
+        assert!(active_executors.is_empty());
+
+        let expired_executors = state.executor_manager.get_expired_executors();
+        assert!(expired_executors.is_empty());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    #[ignore]
+    async fn test_expired_executor() -> Result<(), BallistaError> {
+        let state_storage = Arc::new(StandaloneClient::try_new_temporary()?);
+        let mut scheduler: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
+            SchedulerServer::new(
+                "localhost:50050".to_owned(),
+                state_storage.clone(),
+                BallistaCodec::default(),
+            );
+        scheduler.init().await?;
+
+        let exec_meta = ExecutorRegistration {
+            id: "abc".to_owned(),
+            optional_host: 
Some(OptionalHost::Host("http://localhost:8080".to_owned())),
+            port: 0,
+            grpc_port: 0,
+            specification: Some(ExecutorSpecification { task_slots: 2 
}.into()),
+        };
+
+        let request: Request<RegisterExecutorParams> =
+            Request::new(RegisterExecutorParams {
+                metadata: Some(exec_meta.clone()),
+            });
+        let response = scheduler
+            .register_executor(request)
+            .await
+            .expect("Received error response")
+            .into_inner();
+
+        // registration should success
+        assert!(response.success);
+
+        let state = scheduler.state.clone();
+        // executor should be registered
+        let stored_executor = state
+            .executor_manager
+            .get_executor_metadata("abc")
+            .await
+            .expect("getting executor");
+
+        assert_eq!(stored_executor.grpc_port, 0);
+        assert_eq!(stored_executor.port, 0);
+        assert_eq!(stored_executor.specification.task_slots, 2);
+        assert_eq!(stored_executor.host, "http://localhost:8080".to_owned());
+
+        // heartbeat from the executor
+        let request: Request<HeartBeatParams> = Request::new(HeartBeatParams {
+            executor_id: "abc".to_owned(),
+            metrics: vec![],
+            status: Some(ExecutorStatus {
+                status: Some(executor_status::Status::Active("".to_string())),
+            }),
+        });
+
+        let _response = scheduler
+            .heart_beat_from_executor(request)
+            .await
+            .expect("Received error response")
+            .into_inner();
+
+        let active_executors = state
+            .executor_manager
+            .get_alive_executors_within_one_minute();
+        assert_eq!(active_executors.len(), 1);
+
+        let expired_executors = state.executor_manager.get_expired_executors();
+        assert!(expired_executors.is_empty());
+
+        // simulate the heartbeat timeout
+        
tokio::time::sleep(Duration::from_secs(DEFAULT_EXECUTOR_TIMEOUT_SECONDS)).await;
+        tokio::time::sleep(Duration::from_secs(3)).await;
+
+        // executor should be marked to dead
+        assert!(state.executor_manager.is_dead_executor("abc"));
+
+        let active_executors = state
+            .executor_manager
+            .get_alive_executors_within_one_minute();
+        assert!(active_executors.is_empty());
+        Ok(())
+    }
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs
index 19ce8680..fed268cc 100644
--- a/ballista/rust/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs
@@ -16,22 +16,26 @@
 // under the License.
 
 use std::sync::Arc;
-use std::time::{SystemTime, UNIX_EPOCH};
+use std::time::{Duration, SystemTime, UNIX_EPOCH};
 
 use ballista_core::config::TaskSchedulingPolicy;
 use ballista_core::error::Result;
-use ballista_core::event_loop::EventLoop;
-use ballista_core::serde::protobuf::TaskStatus;
+use ballista_core::event_loop::{EventLoop, EventSender};
+use ballista_core::serde::protobuf::{StopExecutorParams, TaskStatus};
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
 use datafusion::execution::context::{default_session_builder, SessionState};
 use datafusion::logical_plan::LogicalPlan;
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_proto::logical_plan::AsLogicalPlan;
 
+use log::{error, warn};
+
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use crate::scheduler_server::query_stage_scheduler::QueryStageScheduler;
 use crate::state::backend::StateBackendClient;
-use crate::state::executor_manager::ExecutorReservation;
+use crate::state::executor_manager::{
+    ExecutorManager, ExecutorReservation, DEFAULT_EXECUTOR_TIMEOUT_SECONDS,
+};
 use crate::state::SchedulerState;
 
 // include the generated protobuf source as a submodule
@@ -127,6 +131,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
+    /// Initialize the scheduler server.
+    ///
+    /// Initializes the scheduler state, starts the query stage event loop and
+    /// spawns the background task that expires executors whose heartbeats have
+    /// timed out.
     pub async fn init(&mut self) -> Result<()> {
         self.state.init().await?;
         self.query_stage_event_loop.start()?;
+        self.expire_dead_executors()?;
 
         Ok(())
     }
@@ -154,6 +159,15 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
         executor_id: &str,
         tasks_status: Vec<TaskStatus>,
     ) -> Result<()> {
+        // We might receive buggy task updates from dead executors.
+        if self.state.executor_manager.is_dead_executor(executor_id) {
+            let error_msg = format!(
+                "Receive buggy tasks status from dead Executor {}, task status 
update ignored.",
+                executor_id
+            );
+            warn!("{}", error_msg);
+            return Ok(());
+        }
         self.query_stage_event_loop
             .get_sender()?
             .post_event(QueryStageSchedulerEvent::TaskUpdating(
@@ -172,6 +186,91 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan> SchedulerServer<T
             
.post_event(QueryStageSchedulerEvent::ReservationOffering(reservations))
             .await
     }
+
+    /// Spawn a background task which periodically checks the active executors'
+    /// heartbeat status and expires the dead executors.
+    ///
+    /// Every `DEFAULT_EXECUTOR_TIMEOUT_SECONDS` seconds the task asks the
+    /// `ExecutorManager` for expired executors. For each one it:
+    /// 1. removes the executor from the scheduler and posts an `ExecutorLost`
+    ///    event (see `remove_executor`), and
+    /// 2. makes a best-effort `stop_executor` RPC asking the executor to stop.
+    ///
+    /// Returns an error only if the query stage event loop sender is unavailable.
+    fn expire_dead_executors(&self) -> Result<()> {
+        let state = self.state.clone();
+        let event_sender = self.query_stage_event_loop.get_sender()?;
+        tokio::task::spawn(async move {
+            loop {
+                let expired_executors = state.executor_manager.get_expired_executors();
+                for expired in expired_executors {
+                    let executor_id = expired.executor_id.clone();
+                    let stop_reason = format!(
+                        "Executor {} heartbeat timed out after {}s",
+                        executor_id, DEFAULT_EXECUTOR_TIMEOUT_SECONDS
+                    );
+                    warn!("{}", stop_reason);
+
+                    // A failure here must not terminate the expiry loop, so it is
+                    // logged instead of propagated.
+                    if let Err(e) = Self::remove_executor(
+                        state.executor_manager.clone(),
+                        event_sender.clone(),
+                        &executor_id,
+                        Some(stop_reason.clone()),
+                    )
+                    .await
+                    {
+                        error!("Error to remove Executor in Scheduler due to {:?}", e);
+                    }
+
+                    // Best effort: ask the executor to stop itself. The RPC runs in
+                    // its own task so a slow executor cannot stall this loop.
+                    match state.executor_manager.get_client(&executor_id).await {
+                        Ok(mut client) => {
+                            tokio::task::spawn(async move {
+                                if let Err(error) = client
+                                    .stop_executor(StopExecutorParams {
+                                        reason: stop_reason,
+                                        force: true,
+                                    })
+                                    .await
+                                {
+                                    warn!(
+                                        "Failed to send stop_executor rpc due to, {}",
+                                        error
+                                    );
+                                }
+                            });
+                        }
+                        Err(_) => {
+                            warn!(
+                                "Executor is already dead, failed to connect to Executor {}",
+                                executor_id
+                            );
+                        }
+                    }
+                }
+                tokio::time::sleep(Duration::from_secs(DEFAULT_EXECUTOR_TIMEOUT_SECONDS))
+                    .await;
+            }
+        });
+        Ok(())
+    }
+
+    /// Remove an executor from the scheduler and notify the query stage scheduler.
+    ///
+    /// The `ExecutorManager` is updated first. Then a
+    /// `QueryStageSchedulerEvent::ExecutorLost` event carrying `executor_id` and
+    /// `reason` is posted; the query stage scheduler handles it by calling
+    /// `TaskManager::executor_lost`.
+    ///
+    /// # Errors
+    /// Returns an error if removing the executor or posting the event fails. If
+    /// the removal fails, no event is posted.
+    pub(crate) async fn remove_executor(
+        executor_manager: ExecutorManager,
+        event_sender: EventSender<QueryStageSchedulerEvent>,
+        executor_id: &str,
+        reason: Option<String>,
+    ) -> Result<()> {
+        // Update the executor manager immediately, before notifying the event loop.
+        executor_manager
+            .remove_executor(executor_id, reason.clone())
+            .await?;
+
+        let lost_event =
+            QueryStageSchedulerEvent::ExecutorLost(executor_id.to_owned(), reason);
+        event_sender.post_event(lost_event).await?;
+        Ok(())
+    }
 }
 
 #[cfg(all(test, feature = "sled"))]
@@ -311,9 +410,7 @@ mod test {
         Ok(())
     }
 
-    /// This test will exercise the push-based scheduling. We setup our 
scheduler server
-    /// with `SchedulerEventObserver` to listen to `SchedulerServerEvents` and 
then just immediately
-    /// complete the tasks.
+    /// This test will exercise the push-based scheduling.
     #[tokio::test]
     async fn test_push_scheduling() -> Result<()> {
         let plan = test_plan();
diff --git 
a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs 
b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
index 53fe38f5..7d186fcd 100644
--- a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
@@ -141,7 +141,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan>
                 self.state.task_manager.fail_running_job(&job_id).await?;
             }
             QueryStageSchedulerEvent::JobUpdated(job_id) => {
-                error!("Job {} Updated", job_id);
+                info!("Job {} Updated", job_id);
                 self.state.task_manager.update_job(&job_id).await?;
             }
             QueryStageSchedulerEvent::TaskUpdating(executor_id, tasks_status) 
=> {
@@ -166,7 +166,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan>
                     }
                     Err(e) => {
                         error!(
-                            "Failed to update {} task statuses for executor 
{}: {:?}",
+                            "Failed to update {} task statuses for Executor 
{}: {:?}",
                             num_status, executor_id, e
                         );
                         // TODO error handling
@@ -183,6 +183,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan>
                         .await?;
                 }
             }
+            QueryStageSchedulerEvent::ExecutorLost(executor_id, _) => {
+                self.state
+                    .task_manager
+                    .executor_lost(&executor_id)
+                    .await
+                    .unwrap_or_else(|e| {
+                        let msg = format!(
+                            "TaskManager error to handle Executor {} lost: {}",
+                            executor_id, e
+                        );
+                        error!("{}", msg);
+                    });
+            }
         }
 
         Ok(())
diff --git a/ballista/rust/scheduler/src/state/execution_graph.rs 
b/ballista/rust/scheduler/src/state/execution_graph.rs
index 596272d6..08e97d52 100644
--- a/ballista/rust/scheduler/src/state/execution_graph.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph.rs
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
 use std::fmt::{Debug, Formatter};
 use std::sync::Arc;
@@ -45,7 +45,8 @@ use crate::display::print_stage_metrics;
 use crate::planner::DistributedPlanner;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use crate::state::execution_graph::execution_stage::{
-    CompletedStage, ExecutionStage, FailedStage, ResolvedStage, 
UnresolvedStage,
+    CompletedStage, ExecutionStage, FailedStage, ResolvedStage, StageOutput,
+    UnresolvedStage,
 };
 
 mod execution_stage;
@@ -152,6 +153,10 @@ impl ExecutionGraph {
         self.status.clone()
     }
 
+    /// Number of stages in this graph, whatever state each stage is in.
+    pub fn stage_count(&self) -> usize {
+        self.stages.len()
+    }
+
     /// An ExecutionGraph is complete if all its stages are complete
     pub fn complete(&self) -> bool {
         self.stages
@@ -294,12 +299,12 @@ impl ExecutionGraph {
                         output_links,
                     )?);
                 } else {
-                    return Err(BallistaError::Internal(format!(
+                    warn!(
                         "Stage {}/{} is not in running when updating the 
status of tasks {:?}",
                         job_id,
                         stage_id,
                         stage_task_statuses.into_iter().map(|task_status| 
task_status.task_id.map(|task_id| task_id.partition_id)).collect::<Vec<_>>(),
-                    )));
+                    );
                 }
             } else {
                 return Err(BallistaError::Internal(format!(
@@ -485,8 +490,139 @@ impl ExecutionGraph {
         self.output_locations.clone()
     }
 
+    /// Reset the stages affected by losing the executor `executor_id`.
+    ///
+    /// Each internal pass does two things:
+    /// - It resets the tasks of running stages and drops any input partition
+    ///   locations hosted on the executor, rolling the consuming resolved/running
+    ///   stages back to unresolved.
+    /// - It reruns completed input stages whose output was lost.
+    ///
+    /// Passes are repeated until one reports no further resets.
+    ///
+    /// Returns the ids of every stage reset across all passes.
+    pub fn reset_stages(&mut self, executor_id: &str) -> Result<HashSet<usize>> {
+        let mut all_reset: HashSet<usize> = HashSet::new();
+        loop {
+            let pass = self.reset_stages_internal(executor_id)?;
+            if pass.is_empty() {
+                break;
+            }
+            all_reset.extend(pass);
+        }
+        Ok(all_reset)
+    }
+
+    /// Run a single reset pass for the lost executor `executor_id`.
+    ///
+    /// 1. For every unresolved, resolved and running stage, drop the input
+    ///    partition locations hosted on `executor_id`. Running stages also get
+    ///    their tasks reset via `reset_tasks`. Inputs that lost a location are
+    ///    marked incomplete and their producing stages are recorded for
+    ///    resubmission. Resolved/running consumers are rolled back to unresolved.
+    /// 2. Every recorded input stage that is completed gets its tasks for
+    ///    `executor_id` reset and, if any were reset, is moved back to running.
+    ///
+    /// Returns the ids of the stages reset in this pass.
+    fn reset_stages_internal(&mut self, executor_id: &str) -> Result<HashSet<usize>> {
+        let mut reset_stage = HashSet::new();
+        let job_id = self.job_id.clone();
+        let mut stage_events = vec![];
+        let mut resubmit_inputs: HashSet<usize> = HashSet::new();
+        let mut empty_inputs: HashMap<usize, StageOutput> = HashMap::new();
+
+        // check the unresolved, resolved and running stages
+        self.stages
+            .iter_mut()
+            .for_each(|(stage_id, stage)| {
+                let stage_inputs = match stage {
+                    ExecutionStage::UnResolved(stage) => &mut stage.inputs,
+                    ExecutionStage::Resolved(stage) => &mut stage.inputs,
+                    ExecutionStage::Running(stage) => {
+                        let reset = stage.reset_tasks(executor_id);
+                        if reset > 0 {
+                            warn!(
+                                "Reset {} tasks for running job/stage {}/{} on lost Executor {}",
+                                reset, job_id, stage_id, executor_id
+                            );
+                            reset_stage.insert(*stage_id);
+                        }
+                        &mut stage.inputs
+                    }
+                    _ => &mut empty_inputs,
+                };
+
+                // For each stage input, drop the partition locations hosted on the
+                // lost executor and record the input stages that must be resubmitted.
+                let mut rollback_stage = false;
+                stage_inputs.iter_mut().for_each(|(input_stage_id, stage_output)| {
+                    let mut match_found = false;
+                    stage_output.partition_locations.iter_mut().for_each(
+                        |(_partition, locs)| {
+                            // `retain` drops every matching location in a single pass.
+                            // Removing by previously collected indices would shift the
+                            // remaining elements, dropping the wrong locations (or
+                            // panicking) when more than one location matches.
+                            let before = locs.len();
+                            locs.retain(|loc| loc.executor_meta.id != executor_id);
+                            if locs.len() != before {
+                                match_found = true;
+                            }
+                        },
+                    );
+                    if match_found {
+                        stage_output.complete = false;
+                        rollback_stage = true;
+                        resubmit_inputs.insert(*input_stage_id);
+                    }
+                });
+
+                if rollback_stage {
+                    match stage {
+                        ExecutionStage::Resolved(_) => {
+                            stage_events.push(StageEvent::RollBackResolvedStage(*stage_id));
+                            warn!(
+                                "Roll back resolved job/stage {}/{} and change ShuffleReaderExec back to UnresolvedShuffleExec",
+                                job_id, stage_id
+                            );
+                            reset_stage.insert(*stage_id);
+                        }
+                        ExecutionStage::Running(_) => {
+                            stage_events.push(StageEvent::RollBackRunningStage(*stage_id));
+                            warn!(
+                                "Roll back running job/stage {}/{} and change ShuffleReaderExec back to UnresolvedShuffleExec",
+                                job_id, stage_id
+                            );
+                            reset_stage.insert(*stage_id);
+                        }
+                        _ => {}
+                    }
+                }
+            });
+
+        // check and reset the complete stages
+        if !resubmit_inputs.is_empty() {
+            self.stages
+                .iter_mut()
+                .filter(|(stage_id, _stage)| resubmit_inputs.contains(stage_id))
+                .filter_map(|(_stage_id, stage)| {
+                    if let ExecutionStage::Completed(completed) = stage {
+                        Some(completed)
+                    } else {
+                        None
+                    }
+                })
+                .for_each(|stage| {
+                    let reset = stage.reset_tasks(executor_id);
+                    if reset > 0 {
+                        stage_events.push(StageEvent::ReRunCompletedStage(stage.stage_id));
+                        reset_stage.insert(stage.stage_id);
+                        warn!(
+                            "Reset {} tasks for completed job/stage {}/{} on lost Executor {}",
+                            reset, job_id, stage.stage_id, executor_id
+                        )
+                    }
+                });
+        }
+        self.processing_stage_events(stage_events)?;
+        Ok(reset_stage)
+    }
+
     /// Processing stage events for stage state changing
-    fn processing_stage_events(
+    pub fn processing_stage_events(
         &mut self,
         events: Vec<StageEvent>,
     ) -> Result<Option<QueryStageSchedulerEvent>> {
@@ -505,6 +641,15 @@ impl ExecutionGraph {
                     job_err_msg = format!("{}{}\n", job_err_msg, &err_msg);
                     self.fail_stage(stage_id, err_msg);
                 }
+                StageEvent::RollBackRunningStage(stage_id) => {
+                    self.rollback_running_stage(stage_id)?;
+                }
+                StageEvent::RollBackResolvedStage(stage_id) => {
+                    self.rollback_resolved_stage(stage_id)?;
+                }
+                StageEvent::ReRunCompletedStage(stage_id) => {
+                    self.rerun_completed_stage(stage_id);
+                }
             }
         }
 
@@ -581,6 +726,54 @@ impl ExecutionGraph {
         }
     }
 
+    /// Convert a running stage back to unresolved.
+    ///
+    /// Returns `Ok(true)` if the stage was rolled back. Returns `Ok(false)`, with
+    /// a warning, if `stage_id` is unknown or not running; a stage in any other
+    /// state is kept in the graph unchanged.
+    ///
+    /// NOTE(review): if `to_unresolved` fails, the error is returned and the stage
+    /// has already been removed from the graph.
+    fn rollback_running_stage(&mut self, stage_id: usize) -> Result<bool> {
+        match self.stages.remove(&stage_id) {
+            Some(ExecutionStage::Running(stage)) => {
+                self.stages
+                    .insert(stage_id, ExecutionStage::UnResolved(stage.to_unresolved()?));
+                Ok(true)
+            }
+            other => {
+                // `remove` already took the stage out of the map: put back a stage
+                // that is in a different state instead of silently dropping it.
+                if let Some(stage) = other {
+                    self.stages.insert(stage_id, stage);
+                }
+                warn!(
+                    "Fail to find a running stage {}/{} to rollback",
+                    self.job_id(),
+                    stage_id
+                );
+                Ok(false)
+            }
+        }
+    }
+
+    /// Convert a resolved stage back to unresolved.
+    ///
+    /// Returns `Ok(true)` if the stage was rolled back. Returns `Ok(false)`, with
+    /// a warning, if `stage_id` is unknown or not resolved; a stage in any other
+    /// state is kept in the graph unchanged.
+    ///
+    /// NOTE(review): if `to_unresolved` fails, the error is returned and the stage
+    /// has already been removed from the graph.
+    fn rollback_resolved_stage(&mut self, stage_id: usize) -> Result<bool> {
+        match self.stages.remove(&stage_id) {
+            Some(ExecutionStage::Resolved(stage)) => {
+                self.stages
+                    .insert(stage_id, ExecutionStage::UnResolved(stage.to_unresolved()?));
+                Ok(true)
+            }
+            other => {
+                // `remove` already took the stage out of the map: put back a stage
+                // that is in a different state instead of silently dropping it.
+                if let Some(stage) = other {
+                    self.stages.insert(stage_id, stage);
+                }
+                warn!(
+                    "Fail to find a resolved stage {}/{} to rollback",
+                    self.job_id(),
+                    stage_id
+                );
+                Ok(false)
+            }
+        }
+    }
+
+    /// Convert a completed stage back to running so its lost tasks can be rerun.
+    ///
+    /// Returns `true` if the stage was converted. Returns `false`, with a warning,
+    /// if `stage_id` is unknown or not completed; a stage in any other state is
+    /// kept in the graph unchanged.
+    fn rerun_completed_stage(&mut self, stage_id: usize) -> bool {
+        match self.stages.remove(&stage_id) {
+            Some(ExecutionStage::Completed(stage)) => {
+                self.stages
+                    .insert(stage_id, ExecutionStage::Running(stage.to_running()));
+                true
+            }
+            other => {
+                // `remove` already took the stage out of the map: put back a stage
+                // that is in a different state instead of silently dropping it.
+                if let Some(stage) = other {
+                    self.stages.insert(stage_id, stage);
+                }
+                warn!(
+                    "Fail to find a completed stage {}/{} to rerun",
+                    self.job_id(),
+                    stage_id
+                );
+                false
+            }
+        }
+    }
+
     /// fail job with error message
     pub fn fail_job(&mut self, error: String) {
         self.status = JobStatus {
@@ -790,6 +983,7 @@ impl ExecutionStageBuilder {
                     stage,
                     partitioning,
                     output_links,
+                    HashMap::new(),
                 ))
             } else {
                 ExecutionStage::UnResolved(UnresolvedStage::new(
@@ -848,6 +1042,9 @@ pub enum StageEvent {
     StageResolved(usize),
     StageCompleted(usize),
     StageFailed(usize, String),
+    RollBackRunningStage(usize),
+    RollBackResolvedStage(usize),
+    ReRunCompletedStage(usize),
 }
 
 /// Represents the basic unit of work for the Ballista executor. Will execute
@@ -912,10 +1109,10 @@ mod test {
     use datafusion::test_util::scan_empty;
 
     use ballista_core::error::Result;
-    use ballista_core::serde::protobuf::{self, job_status, task_status};
+    use ballista_core::serde::protobuf::{self, job_status, task_status, 
TaskStatus};
     use ballista_core::serde::scheduler::{ExecutorMetadata, 
ExecutorSpecification};
 
-    use crate::state::execution_graph::ExecutionGraph;
+    use crate::state::execution_graph::{ExecutionGraph, Task};
 
     #[tokio::test]
     async fn test_drain_tasks() -> Result<()> {
@@ -989,46 +1186,163 @@ mod test {
         Ok(())
     }
 
-    fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> {
-        let executor = test_executor();
-        let job_id = graph.job_id().to_owned();
-        while let Some(task) = graph.pop_next_task("executor-id")? {
-            let mut partitions: Vec<protobuf::ShuffleWritePartition> = vec![];
-
-            let num_partitions = task
-                .output_partitioning
-                .map(|p| p.partition_count())
-                .unwrap_or(1);
-
-            for partition_id in 0..num_partitions {
-                partitions.push(protobuf::ShuffleWritePartition {
-                    partition_id: partition_id as u64,
-                    path: format!(
-                        "/{}/{}/{}",
-                        task.partition.job_id,
-                        task.partition.stage_id,
-                        task.partition.partition_id
-                    ),
-                    num_batches: 1,
-                    num_rows: 1,
-                    num_bytes: 1,
-                })
-            }
+    /// Losing an executor that holds completed shuffle output while a downstream
+    /// stage is running should roll the running stage back to unresolved and
+    /// rerun the completed input stage; the job must still complete afterwards.
+    #[tokio::test]
+    async fn test_reset_completed_stage() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut join_graph = test_join_plan(4).await;
 
-            // Complete the task
-            let task_status = protobuf::TaskStatus {
-                status: 
Some(task_status::Status::Completed(protobuf::CompletedTask {
-                    executor_id: "executor-1".to_owned(),
-                    partitions,
-                })),
-                metrics: vec![],
-                task_id: Some(protobuf::PartitionId {
-                    job_id: job_id.clone(),
-                    stage_id: task.partition.stage_id as u32,
-                    partition_id: task.partition.partition_id as u32,
-                }),
-            };
+        assert_eq!(join_graph.stage_count(), 5);
+        assert_eq!(join_graph.available_tasks(), 0);
+
+        // Call revive to move the two leaf Resolved stages to Running
+        join_graph.revive();
+
+        assert_eq!(join_graph.stage_count(), 5);
+        assert_eq!(join_graph.available_tasks(), 2);
+
+        // Complete the first stage
+        if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            join_graph.update_task_status(&executor1, vec![task_status])?;
+        }
+
+        // Complete the second stage
+        if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
+            let task_status = mock_completed_task(task, &executor2.id);
+            join_graph.update_task_status(&executor2, vec![task_status])?;
+        }
+
+        join_graph.revive();
+        // There are 4 tasks pending schedule for the 3rd stage
+        assert_eq!(join_graph.available_tasks(), 4);
+
+        // Complete 1 task
+        if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            join_graph.update_task_status(&executor1, vec![task_status])?;
+        }
+        // Mock 1 running task
+        let _task = join_graph.pop_next_task(&executor1.id)?;
+
+        let reset = join_graph.reset_stages(&executor1.id)?;
+
+        // Two stages were reset: 1 Running stage rolled back to Unresolved and
+        // 1 Completed stage moved back to Running
+        assert_eq!(reset.len(), 2);
+        assert_eq!(join_graph.available_tasks(), 1);
+
+        drain_tasks(&mut join_graph)?;
+        assert!(join_graph.complete(), "Failed to complete join plan");
+
+        Ok(())
+    }
+
+    /// Losing an executor that holds completed shuffle output consumed by a
+    /// resolved (not yet running) stage should roll that stage back to
+    /// unresolved and rerun the completed input stage.
+    #[tokio::test]
+    async fn test_reset_resolved_stage() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut join_graph = test_join_plan(4).await;
+
+        // Initially nothing is schedulable.
+        assert_eq!(join_graph.stage_count(), 5);
+        assert_eq!(join_graph.available_tasks(), 0);
+
+        // Revive moves the two leaf Resolved stages to Running.
+        join_graph.revive();
+        assert_eq!(join_graph.stage_count(), 5);
+        assert_eq!(join_graph.available_tasks(), 2);
+
+        // Complete the first leaf stage on executor1, then the second on executor2.
+        for executor in [&executor1, &executor2] {
+            if let Some(task) = join_graph.pop_next_task(&executor.id)? {
+                let task_status = mock_completed_task(task, &executor.id);
+                join_graph.update_task_status(executor, vec![task_status])?;
+            }
+        }
+
+        // No tasks are pending schedule now.
+        assert_eq!(join_graph.available_tasks(), 0);
+
+        let reset = join_graph.reset_stages(&executor1.id)?;
+
+        // 1 Resolved stage rolled back to Unresolved, 1 Completed stage moved to Running.
+        assert_eq!(reset.len(), 2);
+        assert_eq!(join_graph.available_tasks(), 1);
+
+        drain_tasks(&mut join_graph)?;
+        assert!(join_graph.complete(), "Failed to complete join plan");
+
+        Ok(())
+    }
+
+    /// A task status update that arrives from a lost executor after its stages
+    /// were reset must still be accepted. A second reset must then be a no-op,
+    /// and the job must still run to completion.
+    #[tokio::test]
+    async fn test_task_update_after_reset_stage() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(4).await;
+
+        assert_eq!(agg_graph.stage_count(), 2);
+        assert_eq!(agg_graph.available_tasks(), 0);
+
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
+
+        assert_eq!(agg_graph.stage_count(), 2);
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        // Complete the first stage
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status])?;
+        }
+
+        // 1st task in the second stage
+        if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
+            let task_status = mock_completed_task(task, &executor2.id);
+            agg_graph.update_task_status(&executor2, vec![task_status])?;
+        }
+
+        // 2nd task in the second stage
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status])?;
+        }
+
+        // 3rd task in the second stage, scheduled but not completed
+        let task = agg_graph.pop_next_task(&executor1.id)?;
 
+        // There is 1 task pending schedule now
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        let reset = agg_graph.reset_stages(&executor1.id)?;
+
+        // 3rd task status update comes later.
+        let task = task.expect("the 3rd task of the second stage should have been scheduled");
+        let task_status = mock_completed_task(task, &executor1.id);
+        agg_graph.update_task_status(&executor1, vec![task_status])?;
+
+        // Two stages were reset: 1 Running stage rolled back to Unresolved and
+        // 1 Completed stage moved back to Running
+        assert_eq!(reset.len(), 2);
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        // Call the reset again
+        let reset = agg_graph.reset_stages(&executor1.id)?;
+        assert_eq!(reset.len(), 0);
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.complete(), "Failed to complete agg plan");
+
+        Ok(())
+    }
+
+    fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> {
+        let executor = mock_executor("executor-id1".to_string());
+        while let Some(task) = graph.pop_next_task(&executor.id)? {
+            let task_status = mock_completed_task(task, &executor.id);
             graph.update_task_status(&executor, vec![task_status])?;
         }
 
@@ -1179,13 +1493,51 @@ mod test {
         graph
     }
 
-    fn test_executor() -> ExecutorMetadata {
+    /// Build test executor metadata with the given id; host, ports and the
+    /// single task slot are fixed values.
+    fn mock_executor(executor_id: String) -> ExecutorMetadata {
         ExecutorMetadata {
-            id: "executor-2".to_string(),
+            id: executor_id,
             host: "localhost2".to_string(),
             port: 8080,
             grpc_port: 9090,
             specification: ExecutorSpecification { task_slots: 1 },
         }
     }
+
+    /// Build a `Completed` task status for `task`, as if it had run on `executor_id`.
+    ///
+    /// The status reports one shuffle write partition per output partition, or a
+    /// single one when the task has no output partitioning. Each partition has
+    /// path `/{job_id}/{stage_id}/{partition_id}` and 1 batch / row / byte.
+    fn mock_completed_task(task: Task, executor_id: &str) -> TaskStatus {
+        let num_partitions = task
+            .output_partitioning
+            .map(|p| p.partition_count())
+            .unwrap_or(1);
+
+        // Every reported output partition of this task shares the same path.
+        let path = format!(
+            "/{}/{}/{}",
+            task.partition.job_id, task.partition.stage_id, task.partition.partition_id
+        );
+        let partitions: Vec<protobuf::ShuffleWritePartition> = (0..num_partitions)
+            .map(|partition_id| protobuf::ShuffleWritePartition {
+                partition_id: partition_id as u64,
+                path: path.clone(),
+                num_batches: 1,
+                num_rows: 1,
+                num_bytes: 1,
+            })
+            .collect();
+
+        let completed = protobuf::CompletedTask {
+            executor_id: executor_id.to_owned(),
+            partitions,
+        };
+        let task_id = protobuf::PartitionId {
+            job_id: task.partition.job_id.clone(),
+            stage_id: task.partition.stage_id as u32,
+            partition_id: task.partition.partition_id as u32,
+        };
+
+        // Complete the task
+        protobuf::TaskStatus {
+            status: Some(task_status::Status::Completed(completed)),
+            metrics: vec![],
+            task_id: Some(task_id),
+        }
+    }
 }
diff --git 
a/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs 
b/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
index 3e3aee00..b8d590cf 100644
--- a/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
@@ -29,7 +29,9 @@ use log::{debug, warn};
 
 use ballista_core::error::{BallistaError, Result};
 use 
ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
-use ballista_core::serde::protobuf::{self, OperatorMetricsSet};
+use ballista_core::serde::protobuf::{
+    self, CompletedTask, FailedTask, GraphStageInput, OperatorMetricsSet,
+};
 use ballista_core::serde::protobuf::{task_status, RunningTask};
 use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
 use ballista_core::serde::scheduler::PartitionLocation;
@@ -98,6 +100,8 @@ pub(super) struct ResolvedStage {
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the 
`ExecutionGraph`
     pub(super) output_links: Vec<usize>,
+    /// Represents the outputs from this stage's child stages.
+    pub(super) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(super) plan: Arc<dyn ExecutionPlan>,
 }
@@ -119,6 +123,8 @@ pub(super) struct RunningStage {
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the 
`ExecutionGraph`
     pub(super) output_links: Vec<usize>,
+    /// Represents the outputs from this stage's child stages.
+    pub(super) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(super) plan: Arc<dyn ExecutionPlan>,
     /// Status of each already scheduled task. If status is None, the 
partition has not yet been scheduled
@@ -140,6 +146,8 @@ pub(super) struct CompletedStage {
     /// Stage ID of the stage that will take this stages outputs as inputs.
     /// If `output_links` is empty then this the final stage in the 
`ExecutionGraph`
     pub(super) output_links: Vec<usize>,
+    /// Represents the outputs from this stage's child stages.
+    pub(super) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(super) plan: Arc<dyn ExecutionPlan>,
     /// Status of each already scheduled task.
@@ -177,10 +185,10 @@ impl UnresolvedStage {
         plan: Arc<dyn ExecutionPlan>,
         output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
-        child_stages: Vec<usize>,
+        child_stage_ids: Vec<usize>,
     ) -> Self {
         let mut inputs: HashMap<usize, StageOutput> = HashMap::new();
-        for input_stage_id in child_stages {
+        for input_stage_id in child_stage_ids {
             inputs.insert(input_stage_id, StageOutput::new());
         }
 
@@ -193,6 +201,22 @@ impl UnresolvedStage {
         }
     }
 
+    /// Build an `UnresolvedStage` from an already-populated `inputs` map.
+    ///
+    /// Unlike `new`, which starts every child stage with an empty
+    /// `StageOutput`, this keeps the supplied input state as-is, so partition
+    /// locations already collected from child stages are carried over.
+    pub(super) fn new_with_inputs(
+        stage_id: usize,
+        plan: Arc<dyn ExecutionPlan>,
+        output_partitioning: Option<Partitioning>,
+        output_links: Vec<usize>,
+        inputs: HashMap<usize, StageOutput>,
+    ) -> Self {
+        Self {
+            stage_id,
+            plan,
+            output_partitioning,
+            output_links,
+            inputs,
+        }
+    }
+
     /// Add input partitions published from an input stage.
     pub(super) fn add_input_partitions(
         &mut self,
@@ -239,6 +263,7 @@ impl UnresolvedStage {
             plan,
             self.output_partitioning.clone(),
             self.output_links.clone(),
+            self.inputs.clone(),
         ))
     }
 
@@ -260,32 +285,7 @@ impl UnresolvedStage {
             plan.schema().as_ref(),
         )?;
 
-        let mut inputs: HashMap<usize, StageOutput> = HashMap::new();
-        for input in stage.inputs {
-            let stage_id = input.stage_id as usize;
-
-            let outputs = input
-                .partition_locations
-                .into_iter()
-                .map(|loc| {
-                    let partition = loc.partition as usize;
-                    let locations = loc
-                        .partition_location
-                        .into_iter()
-                        .map(|l| l.try_into())
-                        .collect::<Result<Vec<_>>>()?;
-                    Ok((partition, locations))
-                })
-                .collect::<Result<HashMap<usize, Vec<PartitionLocation>>>>()?;
-
-            inputs.insert(
-                stage_id,
-                StageOutput {
-                    partition_locations: outputs,
-                    complete: input.complete,
-                },
-            );
-        }
+        let inputs = decode_inputs(stage.inputs)?;
 
         Ok(UnresolvedStage {
             stage_id: stage.stage_id as usize,
@@ -304,26 +304,7 @@ impl UnresolvedStage {
         U::try_from_physical_plan(stage.plan, codec.physical_extension_codec())
             .and_then(|proto| proto.try_encode(&mut plan))?;
 
-        let mut inputs: Vec<protobuf::GraphStageInput> = vec![];
-        for (stage_id, output) in stage.inputs.into_iter() {
-            inputs.push(protobuf::GraphStageInput {
-                stage_id: stage_id as u32,
-                partition_locations: output
-                    .partition_locations
-                    .into_iter()
-                    .map(|(partition, locations)| {
-                        Ok(protobuf::TaskInputPartitions {
-                            partition: partition as u32,
-                            partition_location: locations
-                                .into_iter()
-                                .map(|l| l.try_into())
-                                .collect::<Result<Vec<_>>>()?,
-                        })
-                    })
-                    .collect::<Result<Vec<_>>>()?,
-                complete: output.complete,
-            });
-        }
+        let inputs = encode_inputs(stage.inputs)?;
 
         let output_partitioning =
             hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
@@ -359,6 +340,7 @@ impl ResolvedStage {
         plan: Arc<dyn ExecutionPlan>,
         output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
+        inputs: HashMap<usize, StageOutput>,
     ) -> Self {
         let partitions = plan.output_partitioning().partition_count();
 
@@ -367,6 +349,7 @@ impl ResolvedStage {
             partitions,
             output_partitioning,
             output_links,
+            inputs,
             plan,
         }
     }
@@ -379,9 +362,24 @@ impl ResolvedStage {
             self.partitions,
             self.output_partitioning.clone(),
             self.output_links.clone(),
+            self.inputs.clone(),
         )
     }
 
+    /// Roll this resolved stage back to the unresolved state.
+    ///
+    /// The plan is passed through `crate::planner::rollback_resolved_shuffles`
+    /// (presumably reverting resolved shuffle reads — confirm in planner.rs),
+    /// and the current `inputs`, output partitioning and output links are
+    /// carried over unchanged.
+    pub(super) fn to_unresolved(&self) -> Result<UnresolvedStage> {
+        let rolled_back_plan =
+            crate::planner::rollback_resolved_shuffles(self.plan.clone())?;
+        Ok(UnresolvedStage::new_with_inputs(
+            self.stage_id,
+            rolled_back_plan,
+            self.output_partitioning.clone(),
+            self.output_links.clone(),
+            self.inputs.clone(),
+        ))
+    }
+
     pub(super) fn decode<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan>(
         stage: protobuf::ResolvedStage,
         codec: &BallistaCodec<T, U>,
@@ -400,11 +398,14 @@ impl ResolvedStage {
             plan.schema().as_ref(),
         )?;
 
+        let inputs = decode_inputs(stage.inputs)?;
+
         Ok(ResolvedStage {
             stage_id: stage.stage_id as usize,
             partitions: stage.partitions as usize,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as 
usize).collect(),
+            inputs,
             plan,
         })
     }
@@ -420,11 +421,14 @@ impl ResolvedStage {
         let output_partitioning =
             hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
 
+        let inputs = encode_inputs(stage.inputs)?;
+
         Ok(protobuf::ResolvedStage {
             stage_id: stage.stage_id as u64,
             partitions: stage.partitions as u32,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as 
u32).collect(),
+            inputs,
             plan,
         })
     }
@@ -449,12 +453,14 @@ impl RunningStage {
         partitions: usize,
         output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
+        inputs: HashMap<usize, StageOutput>,
     ) -> Self {
         Self {
             stage_id,
             partitions,
             output_partitioning,
             output_links,
+            inputs,
             plan,
             task_statuses: vec![None; partitions],
             stage_metrics: None,
@@ -484,6 +490,7 @@ impl RunningStage {
             partitions: self.partitions,
             output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
+            inputs: self.inputs.clone(),
             plan: self.plan.clone(),
             task_statuses,
             stage_metrics,
@@ -509,9 +516,24 @@ impl RunningStage {
             self.plan.clone(),
             self.output_partitioning.clone(),
             self.output_links.clone(),
+            self.inputs.clone(),
         )
     }
 
+    /// Roll this running stage back to the unresolved state.
+    ///
+    /// The plan is passed through `crate::planner::rollback_resolved_shuffles`
+    /// (presumably reverting resolved shuffle reads — confirm in planner.rs).
+    /// The current `inputs`, output partitioning and output links are carried
+    /// over; per-task statuses and metrics are not, since
+    /// `UnresolvedStage::new_with_inputs` takes no such arguments.
+    pub(super) fn to_unresolved(&self) -> Result<UnresolvedStage> {
+        let rolled_back_plan =
+            crate::planner::rollback_resolved_shuffles(self.plan.clone())?;
+        Ok(UnresolvedStage::new_with_inputs(
+            self.stage_id,
+            rolled_back_plan,
+            self.output_partitioning.clone(),
+            self.output_links.clone(),
+            self.inputs.clone(),
+        ))
+    }
+
     /// Returns `true` if all tasks for this stage are complete
     pub(super) fn is_completed(&self) -> bool {
         self.task_statuses
@@ -616,6 +638,31 @@ impl RunningStage {
         }
         first.aggregate_by_partition()
     }
+
+    /// Reset every running or completed task that was assigned to `executor`.
+    ///
+    /// Each matching task slot is set back to `None`, which the struct
+    /// documents as "not yet scheduled". Returns the total number of tasks
+    /// reset, counting both running and completed tasks.
+    pub fn reset_tasks(&mut self, executor: &str) -> usize {
+        let mut reset = 0;
+        for task in self.task_statuses.iter_mut() {
+            let assigned_to_executor = match task {
+                Some(task_status::Status::Running(RunningTask { executor_id }))
+                | Some(task_status::Status::Completed(CompletedTask {
+                    executor_id,
+                    ..
+                })) => *executor == *executor_id,
+                _ => false,
+            };
+            if assigned_to_executor {
+                *task = None;
+                reset += 1;
+            }
+        }
+        reset
+    }
 }
 
 impl Debug for RunningStage {
@@ -636,6 +683,48 @@ impl Debug for RunningStage {
 }
 
 impl CompletedStage {
+    /// Convert this completed stage back into a `RunningStage`.
+    ///
+    /// Task slots holding a `Completed` status keep it; every other status
+    /// (e.g. a `Failed` status written by `reset_tasks`) becomes `None`, i.e.
+    /// not yet scheduled. The stage metrics are carried over as `Some(..)`.
+    pub fn to_running(&self) -> RunningStage {
+        let task_statuses = self
+            .task_statuses
+            .iter()
+            .map(|status| {
+                matches!(status, task_status::Status::Completed(_))
+                    .then(|| status.clone())
+            })
+            .collect();
+        RunningStage {
+            stage_id: self.stage_id,
+            partitions: self.partitions,
+            output_partitioning: self.output_partitioning.clone(),
+            output_links: self.output_links.clone(),
+            inputs: self.inputs.clone(),
+            plan: self.plan.clone(),
+            task_statuses,
+            stage_metrics: Some(self.stage_metrics.clone()),
+        }
+    }
+
+    /// Mark every completed task that ran on `executor` as failed.
+    ///
+    /// Each matching `Completed` status is replaced with a `Failed` status
+    /// whose error message names the lost executor. Returns the number of
+    /// completed tasks that were reset.
+    pub fn reset_tasks(&mut self, executor: &str) -> usize {
+        let mut reset = 0;
+        let failure_reason = format!("Task failure due to Executor {} lost", executor);
+        for task in self.task_statuses.iter_mut() {
+            match task {
+                task_status::Status::Completed(CompletedTask { executor_id, .. })
+                    if *executor == *executor_id =>
+                {
+                    *task = task_status::Status::Failed(FailedTask {
+                        error: failure_reason.clone(),
+                    });
+                    reset += 1;
+                }
+                _ => {}
+            }
+        }
+        reset
+    }
+
     pub(super) fn decode<T: 'static + AsLogicalPlan, U: 'static + 
AsExecutionPlan>(
         stage: protobuf::CompletedStage,
         codec: &BallistaCodec<T, U>,
@@ -654,6 +743,8 @@ impl CompletedStage {
             plan.schema().as_ref(),
         )?;
 
+        let inputs = decode_inputs(stage.inputs)?;
+
         let task_statuses = stage
             .task_statuses
             .into_iter()
@@ -676,6 +767,7 @@ impl CompletedStage {
             partitions: stage.partitions as usize,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as 
usize).collect(),
+            inputs,
             plan,
             task_statuses,
             stage_metrics,
@@ -696,6 +788,8 @@ impl CompletedStage {
         let output_partitioning =
             hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
 
+        let inputs = encode_inputs(stage.inputs)?;
+
         let task_statuses: Vec<protobuf::TaskStatus> = stage
             .task_statuses
             .into_iter()
@@ -725,6 +819,7 @@ impl CompletedStage {
             partitions: stage.partitions as u32,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as 
u32).collect(),
+            inputs,
             plan,
             task_statuses,
             stage_metrics,
@@ -894,9 +989,9 @@ impl Debug for FailedStage {
 #[derive(Clone, Debug, Default)]
 pub(super) struct StageOutput {
     /// Map from partition -> partition locations
-    partition_locations: HashMap<usize, Vec<PartitionLocation>>,
+    // `pub(super)` like the other stage structs' fields; the struct itself is
+    // `pub(super)`, so plain `pub` granted no additional access.
+    pub(super) partition_locations: HashMap<usize, Vec<PartitionLocation>>,
     /// Flag indicating whether all tasks are complete
-    complete: bool,
+    pub(super) complete: bool,
 }
 
 impl StageOutput {
@@ -926,3 +1021,61 @@ impl StageOutput {
         self.complete
     }
 }
+
+/// Decode protobuf `GraphStageInput` messages into a map keyed by input stage ID.
+///
+/// Fails if any partition location cannot be converted. If the same stage ID
+/// appears more than once, the later entry replaces the earlier one.
+fn decode_inputs(
+    stage_inputs: Vec<GraphStageInput>,
+) -> Result<HashMap<usize, StageOutput>> {
+    stage_inputs
+        .into_iter()
+        .map(|input| {
+            let partition_locations = input
+                .partition_locations
+                .into_iter()
+                .map(|task_inputs| {
+                    let locations = task_inputs
+                        .partition_location
+                        .into_iter()
+                        .map(|l| l.try_into())
+                        .collect::<Result<Vec<_>>>()?;
+                    Ok((task_inputs.partition as usize, locations))
+                })
+                .collect::<Result<HashMap<usize, Vec<PartitionLocation>>>>()?;
+            let output = StageOutput {
+                partition_locations,
+                complete: input.complete,
+            };
+            Ok((input.stage_id as usize, output))
+        })
+        .collect::<Result<HashMap<usize, StageOutput>>>()
+}
+
+/// Encode a stage's input map as protobuf `GraphStageInput` messages, one per
+/// input stage, in the map's iteration order.
+///
+/// Fails if any partition location cannot be converted.
+fn encode_inputs(
+    stage_inputs: HashMap<usize, StageOutput>,
+) -> Result<Vec<GraphStageInput>> {
+    stage_inputs
+        .into_iter()
+        .map(|(stage_id, output)| {
+            let partition_locations = output
+                .partition_locations
+                .into_iter()
+                .map(|(partition, locations)| {
+                    let partition_location = locations
+                        .into_iter()
+                        .map(|l| l.try_into())
+                        .collect::<Result<Vec<_>>>()?;
+                    Ok(protobuf::TaskInputPartitions {
+                        partition: partition as u32,
+                        partition_location,
+                    })
+                })
+                .collect::<Result<Vec<_>>>()?;
+            Ok(protobuf::GraphStageInput {
+                stage_id: stage_id as u32,
+                partition_locations,
+                complete: output.complete,
+            })
+        })
+        .collect::<Result<Vec<_>>>()
+}
diff --git a/ballista/rust/scheduler/src/state/executor_manager.rs b/ballista/rust/scheduler/src/state/executor_manager.rs
index 4c5449c8..9fc8df90 100644
--- a/ballista/rust/scheduler/src/state/executor_manager.rs
+++ b/ballista/rust/scheduler/src/state/executor_manager.rs
@@ -24,6 +24,9 @@ use ballista_core::error::{BallistaError, Result};
 use ballista_core::serde::protobuf;
 
 use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
+use ballista_core::serde::protobuf::{
+    executor_status, ExecutorHeartbeat, ExecutorStatus,
+};
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::utils::create_grpc_client_connection;
 use futures::StreamExt;
@@ -71,11 +74,21 @@ impl ExecutorReservation {
     }
 }
 
+// TODO move to configuration file
+/// Default executor timeout in seconds. It should be longer than the executor's heartbeat
+/// interval, so that an executor is only marked as dead after missing two or three
+/// consecutive heartbeats.
+pub const DEFAULT_EXECUTOR_TIMEOUT_SECONDS: u64 = 180;
+
 #[derive(Clone)]
 pub(crate) struct ExecutorManager {
     state: Arc<dyn StateBackendClient>,
+    // executor_id -> ExecutorMetadata
     executor_metadata: Arc<RwLock<HashMap<String, ExecutorMetadata>>>,
+    // executor_id -> most recently seen heartbeat; an entry is removed when the
+    // executor is marked dead
     executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
+    // IDs of executors that have been marked dead (shared with the heartbeat listener)
+    dead_executors: Arc<RwLock<HashSet<String>>>,
     clients: ExecutorClients,
 }
 
@@ -85,17 +98,19 @@ impl ExecutorManager {
             state,
             executor_metadata: Arc::new(RwLock::new(HashMap::new())),
             executors_heartbeat: Arc::new(RwLock::new(HashMap::new())),
+            dead_executors: Arc::new(RwLock::new(HashSet::new())),
             clients: Default::default(),
         }
     }
 
     /// Initialize the `ExecutorManager` state. This will fill the `executor_heartbeats` value
-    /// with existing heartbeats. Then new updates will be consumed through the `ExecutorHeartbeatListener`
+    /// with existing active heartbeats. Then new updates will be consumed through the `ExecutorHeartbeatListener`
+    /// Stored heartbeats whose status is not `Active` are not loaded into the heartbeat cache.
+    /// The listener is given the shared `dead_executors` set so it can record executors
+    /// reported as dead.
     pub async fn init(&self) -> Result<()> {
-        self.init_executor_heartbeats().await?;
+        self.init_active_executor_heartbeats().await?;
         let heartbeat_listener = ExecutorHeartbeatListener::new(
             self.state.clone(),
             self.executors_heartbeat.clone(),
+            self.dead_executors.clone(),
         );
         heartbeat_listener.start().await
     }
@@ -316,7 +331,10 @@ impl ExecutorManager {
         self.save_executor_heartbeat(protobuf::ExecutorHeartbeat {
             executor_id: executor_id.clone(),
             timestamp: current_ts,
-            state: None,
+            metrics: vec![],
+            status: Some(ExecutorStatus {
+                status: Some(executor_status::Status::Active("".to_string())),
+            }),
         })
         .await?;
 
@@ -341,6 +359,37 @@ impl ExecutorManager {
         }
     }
 
+    /// Mark `executor_id` as dead within the scheduler.
+    ///
+    /// Builds a heartbeat stamped with the current time (seconds since the Unix
+    /// epoch) and a `Dead` status, and hands it to `save_dead_executor_heartbeat`.
+    /// `_reason` is currently unused.
+    pub async fn remove_executor(
+        &self,
+        executor_id: &str,
+        _reason: Option<String>,
+    ) -> Result<()> {
+        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
+            BallistaError::Internal(format!(
+                "Error getting current timestamp: {:?}",
+                e
+            ))
+        })?;
+
+        let dead_status = ExecutorStatus {
+            status: Some(executor_status::Status::Dead("".to_string())),
+        };
+        let heartbeat = protobuf::ExecutorHeartbeat {
+            executor_id: executor_id.to_owned(),
+            timestamp: since_epoch.as_secs(),
+            metrics: vec![],
+            status: Some(dead_status),
+        };
+
+        // TODO Check the Executor reservation logic for push-based scheduling
+        self.save_dead_executor_heartbeat(heartbeat).await
+    }
+
     #[cfg(not(test))]
     async fn test_scheduler_connectivity(
         &self,
@@ -383,15 +432,42 @@ impl ExecutorManager {
         Ok(())
     }
 
-    /// Initialize the set of executor heartbeats from storage
-    pub(crate) async fn init_executor_heartbeats(&self) -> Result<()> {
+    /// Persist `heartbeat` under `Keyspace::Heartbeats`, then drop the executor
+    /// from the in-memory heartbeat cache and add it to the dead-executor set.
+    ///
+    /// Callers such as `remove_executor` pass a heartbeat with a `Dead` status.
+    pub(crate) async fn save_dead_executor_heartbeat(
+        &self,
+        heartbeat: protobuf::ExecutorHeartbeat,
+    ) -> Result<()> {
+        let value = encode_protobuf(&heartbeat)?;
+        // Only the encoded bytes are needed from here on, so move the ID out
+        // instead of cloning it.
+        let executor_id = heartbeat.executor_id;
+        self.state
+            .put(Keyspace::Heartbeats, executor_id.clone(), value)
+            .await?;
+
+        // Lock order (heartbeat cache, then dead set) matches the heartbeat
+        // listener's watch loop.
+        let mut executors_heartbeat = self.executors_heartbeat.write();
+        executors_heartbeat.remove(&executor_id);
+
+        let mut dead_executors = self.dead_executors.write();
+        dead_executors.insert(executor_id);
+        Ok(())
+    }
+
+    /// Returns `true` if `executor_id` is recorded in the dead-executor set.
+    pub(crate) fn is_dead_executor(&self, executor_id: &str) -> bool {
+        let dead_executors = self.dead_executors.read();
+        dead_executors.contains(executor_id)
+    }
+
+    /// Initialize the in-memory executor state from the heartbeats in storage.
+    ///
+    /// Heartbeats with an `Active` status populate the heartbeat cache. Heartbeats
+    /// with a `Dead` status are recorded in the dead-executor set, as the
+    /// heartbeat listener does for watched updates, so `is_dead_executor` keeps
+    /// reporting them after a scheduler restart.
+    async fn init_active_executor_heartbeats(&self) -> Result<()> {
         let heartbeats = self.state.scan(Keyspace::Heartbeats, None).await?;
         let mut cache = self.executors_heartbeat.write();
+        // Lock order (heartbeat cache, then dead set) matches the other writers.
+        let mut dead_executors = self.dead_executors.write();
 
         for (_, value) in heartbeats {
             let data: protobuf::ExecutorHeartbeat = decode_protobuf(&value)?;
             let executor_id = data.executor_id.clone();
-            cache.insert(executor_id, data);
+            match data.status {
+                Some(ExecutorStatus {
+                    status: Some(executor_status::Status::Active(_)),
+                }) => {
+                    cache.insert(executor_id, data);
+                }
+                Some(ExecutorStatus {
+                    status: Some(executor_status::Status::Dead(_)),
+                }) => {
+                    dead_executors.insert(executor_id);
+                }
+                // Heartbeats without a recognised status are not cached.
+                _ => {}
+            }
         }
         Ok(())
     }
@@ -411,8 +487,27 @@ impl ExecutorManager {
             .collect()
     }
 
-    #[allow(dead_code)]
-    fn get_alive_executors_within_one_minute(&self) -> HashSet<String> {
+    /// Return the cached heartbeats of expired executors: those whose last
+    /// heartbeat timestamp is at or before now minus
+    /// `DEFAULT_EXECUTOR_TIMEOUT_SECONDS` (clamped at the Unix epoch).
+    pub(crate) fn get_expired_executors(&self) -> Vec<ExecutorHeartbeat> {
+        let timeout = Duration::from_secs(DEFAULT_EXECUTOR_TIMEOUT_SECONDS);
+        let now = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .expect("Time went backwards");
+        let threshold_secs = now
+            .checked_sub(timeout)
+            .unwrap_or_else(|| Duration::from_secs(0))
+            .as_secs();
+
+        let heartbeats = self.executors_heartbeat.read();
+        heartbeats
+            .values()
+            .filter(|heartbeat| heartbeat.timestamp <= threshold_secs)
+            .cloned()
+            .collect()
+    }
+
+    pub(crate) fn get_alive_executors_within_one_minute(&self) -> 
HashSet<String> {
         let now_epoch_ts = SystemTime::now()
             .duration_since(UNIX_EPOCH)
             .expect("Time went backwards");
@@ -429,16 +524,19 @@ impl ExecutorManager {
+/// Watches `Keyspace::Heartbeats` and mirrors updates into the shared in-memory
+/// heartbeat cache and dead-executor set.
 struct ExecutorHeartbeatListener {
     state: Arc<dyn StateBackendClient>,
     executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
+    // Shared with `ExecutorManager`: IDs of executors whose heartbeat reported `Dead`.
+    dead_executors: Arc<RwLock<HashSet<String>>>,
 }
 
 impl ExecutorHeartbeatListener {
+    /// Create a listener over `state` that updates the given shared heartbeat
+    /// cache and dead-executor set.
     pub fn new(
         state: Arc<dyn StateBackendClient>,
         executors_heartbeat: Arc<RwLock<HashMap<String, protobuf::ExecutorHeartbeat>>>,
+        dead_executors: Arc<RwLock<HashSet<String>>>,
     ) -> Self {
         Self {
             state,
             executors_heartbeat,
+            dead_executors,
         }
     }
 
@@ -450,6 +548,7 @@ impl ExecutorHeartbeatListener {
             .watch(Keyspace::Heartbeats, "".to_owned())
             .await?;
         let heartbeats = self.executors_heartbeat.clone();
+        let dead_executors = self.dead_executors.clone();
         tokio::task::spawn(async move {
             while let Some(event) = watch.next().await {
                 if let WatchEvent::Put(_, value) = event {
@@ -458,13 +557,20 @@ impl ExecutorHeartbeatListener {
                     {
                         let executor_id = data.executor_id.clone();
                         let mut heartbeats = heartbeats.write();
-
-                        heartbeats.insert(executor_id, data);
+                        // Remove dead executors
+                        if let Some(ExecutorStatus {
+                            status: Some(executor_status::Status::Dead(_)),
+                        }) = data.status
+                        {
+                            heartbeats.remove(&executor_id);
+                            dead_executors.write().insert(executor_id);
+                        } else {
+                            heartbeats.insert(executor_id, data);
+                        }
                     }
                 }
             }
         });
-
         Ok(())
     }
 }
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs b/ballista/rust/scheduler/src/state/task_manager.rs
index 8e205432..e34c04c8 100644
--- a/ballista/rust/scheduler/src/state/task_manager.rs
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -381,6 +381,36 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         Ok(())
     }
 
+    /// Handle the loss of `executor_id` across all active jobs.
+    ///
+    /// Calls `reset_stages(executor_id)` on every cached active `ExecutionGraph`.
+    /// If any graph changed, all changed graphs are written back to
+    /// `Keyspace::ActiveJobs` in one `put_txn` while holding the keyspace lock.
+    ///
+    /// NOTE(review): if `reset_stages` fails part-way, graphs already reset in
+    /// memory are not persisted — confirm this is acceptable.
+    pub async fn executor_lost(&self, executor_id: &str) -> Result<()> {
+        // Collect graphs we update so we can update them in storage
+        let mut updated_graphs: HashMap<String, ExecutionGraph> = HashMap::new();
+        {
+            let job_cache = self.active_job_cache.read().await;
+            for (job_id, graph) in job_cache.iter() {
+                let mut graph = graph.write().await;
+                let reset = graph.reset_stages(executor_id)?;
+                if !reset.is_empty() {
+                    updated_graphs.insert(job_id.to_owned(), graph.clone());
+                }
+            }
+        }
+
+        // Nothing changed: skip taking the ActiveJobs lock and writing an empty
+        // transaction.
+        if updated_graphs.is_empty() {
+            return Ok(());
+        }
+
+        let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
+        with_lock(lock, async {
+            // Transactional update graphs
+            let txn_ops: Vec<(Keyspace, String, Vec<u8>)> = updated_graphs
+                .into_iter()
+                .map(|(job_id, graph)| {
+                    let value = self.encode_execution_graph(graph)?;
+                    Ok((Keyspace::ActiveJobs, job_id, value))
+                })
+                .collect::<Result<Vec<_>>>()?;
+            self.state.put_txn(txn_ops).await?;
+            Ok(())
+        })
+        .await
+    }
+
     #[cfg(not(test))]
     /// Launch the given task on the specified executor
     pub(crate) async fn launch_task(

Reply via email to