You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by nj...@apache.org on 2022/10/02 09:50:34 UTC

[arrow-ballista] branch master updated: Task level retry and Stage level retry (#261)

This is an automated email from the ASF dual-hosted git repository.

nju_yaho pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git


The following commit(s) were added to refs/heads/master by this push:
     new f5bfef00 Task level retry and Stage level retry (#261)
f5bfef00 is described below

commit f5bfef00bcb695c68f377bdd23fa3efdefa1f43c
Author: mingmwang <mi...@ebay.com>
AuthorDate: Sun Oct 2 17:50:30 2022 +0800

    Task level retry and Stage level retry (#261)
    
    * Task level failure retry and Stage level failure retry
    
    * Add UT
    
    * fix fmt
    
    * Resolve review comments
---
 ballista/rust/core/proto/ballista.proto            |  165 +-
 ballista/rust/core/src/client.rs                   |   42 +-
 ballista/rust/core/src/error.rs                    |   85 +-
 .../core/src/execution_plans/distributed_query.rs  |   11 +-
 .../core/src/execution_plans/shuffle_reader.rs     |  219 ++-
 .../rust/core/src/serde/scheduler/from_proto.rs    |   61 +-
 ballista/rust/core/src/serde/scheduler/mod.rs      |   17 +-
 ballista/rust/core/src/serde/scheduler/to_proto.rs |   26 +-
 ballista/rust/core/src/utils.rs                    |    7 +-
 ballista/rust/executor/src/execution_loop.rs       |   67 +-
 ballista/rust/executor/src/executor.rs             |   60 +-
 ballista/rust/executor/src/executor_server.rs      |  109 +-
 ballista/rust/executor/src/lib.rs                  |   41 +-
 ballista/rust/scheduler/src/api/handlers.rs        |    2 +-
 ballista/rust/scheduler/src/flight_sql.rs          |    8 +-
 .../rust/scheduler/src/scheduler_server/event.rs   |    4 +-
 .../rust/scheduler/src/scheduler_server/grpc.rs    |   22 +-
 .../rust/scheduler/src/scheduler_server/mod.rs     |   68 +-
 .../src/scheduler_server/query_stage_scheduler.rs  |   46 +-
 .../rust/scheduler/src/state/execution_graph.rs    | 2042 +++++++++++++++++---
 .../src/state/execution_graph/execution_stage.rs   |  565 ++++--
 .../scheduler/src/state/execution_graph_dot.rs     |    2 +-
 .../rust/scheduler/src/state/executor_manager.rs   |   46 +-
 ballista/rust/scheduler/src/state/mod.rs           |   74 +-
 ballista/rust/scheduler/src/state/task_manager.rs  |  226 +--
 25 files changed, 3058 insertions(+), 957 deletions(-)

diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto
index 8923cbc9..84cecffa 100644
--- a/ballista/rust/core/proto/ballista.proto
+++ b/ballista/rust/core/proto/ballista.proto
@@ -434,54 +434,87 @@ message ExecutionGraph {
   uint64 output_partitions = 5;
   repeated PartitionLocation output_locations = 6;
   string scheduler_id = 7;
+  uint32 task_id_gen = 8;
+  repeated StageAttempts failed_attempts = 9;
+}
+
+message StageAttempts {
+  uint32 stage_id = 1;
+  repeated uint32 stage_attempt_num = 2;
 }
 
 message ExecutionGraphStage {
   oneof StageType {
       UnResolvedStage unresolved_stage = 1;
       ResolvedStage resolved_stage = 2;
-      CompletedStage completed_stage = 3;
+      SuccessfulStage successful_stage = 3;
       FailedStage failed_stage = 4;
   }
 }
 
 message UnResolvedStage {
-  uint64 stage_id = 1;
+  uint32 stage_id = 1;
   PhysicalHashRepartition output_partitioning = 2;
   repeated uint32 output_links = 3;
   repeated  GraphStageInput inputs = 4;
   bytes plan = 5;
+  uint32 stage_attempt_num = 6;
+  repeated string last_attempt_failure_reasons = 7;
 }
 
 message ResolvedStage {
-  uint64 stage_id = 1;
+  uint32 stage_id = 1;
   uint32 partitions = 2;
   PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
   repeated  GraphStageInput inputs = 5;
   bytes plan = 6;
+  uint32 stage_attempt_num = 7;
+  repeated string last_attempt_failure_reasons = 8;
 }
 
-message CompletedStage {
-  uint64 stage_id = 1;
+message SuccessfulStage {
+  uint32 stage_id = 1;
   uint32 partitions = 2;
   PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
   repeated  GraphStageInput inputs = 5;
   bytes plan = 6;
-  repeated TaskStatus task_statuses = 7;
+  repeated TaskInfo task_infos = 7;
   repeated OperatorMetricsSet stage_metrics = 8;
+  uint32 stage_attempt_num = 9;
 }
 
 message FailedStage {
-  uint64 stage_id = 1;
+  uint32 stage_id = 1;
   uint32 partitions = 2;
   PhysicalHashRepartition output_partitioning = 3;
   repeated uint32 output_links = 4;
   bytes plan = 5;
-  repeated TaskStatus task_statuses = 6;
+  repeated TaskInfo task_infos = 6;
   repeated OperatorMetricsSet stage_metrics = 7;
   string error_message = 8;
+  uint32 stage_attempt_num = 9;
+}
+
+message TaskInfo {
+  uint32 task_id = 1;
+  uint32 partition_id = 2;
+  // Scheduler schedule time
+  uint64 scheduled_time = 3;
+  // Scheduler launch time
+  uint64 launch_time = 4;
+  // The time the Executor starts to run the task
+  uint64 start_exec_time = 5;
+  // The time the Executor finishes the task
+  uint64 end_exec_time = 6;
+  // Scheduler side finish time
+  uint64 finish_time = 7;
+  oneof status {
+    RunningTask running = 8;
+    FailedTask failed = 9;
+    SuccessfulTask successful = 10;
+  }
 }
 
 message GraphStageInput {
@@ -531,12 +564,14 @@ message FetchPartition {
   uint32 port = 6;
 }
 
-// Mapping from partition id to executor id
 message PartitionLocation {
-  PartitionId partition_id = 1;
-  ExecutorMetadata executor_meta = 2;
-  PartitionStats partition_stats = 3;
-  string path = 4;
+  // partition_id within the map stage that produces the shuffle.
+  uint32 map_partition_id = 1;
+  // partition_id of the shuffle, composed of (job_id + map_stage_id + partition_id).
+  PartitionId partition_id = 2;
+  ExecutorMetadata executor_meta = 3;
+  PartitionStats partition_stats = 4;
+  string path = 5;
 }
 
 // Unique identifier for a materialized partition of data
@@ -546,11 +581,10 @@ message PartitionId {
   uint32 partition_id = 4;
 }
 
-// Multiple partitions in the same stage
-message PartitionIds {
-  string job_id = 1;
-  uint32 stage_id = 2;
-  repeated uint32 partition_ids = 4;
+message TaskId {
+  uint32 task_id = 1;
+  uint32 task_attempt_num = 2;
+  uint32 partition_id = 3;
 }
 
 message PartitionStats {
@@ -674,15 +708,48 @@ message RunningTask {
 
 message FailedTask {
   string error = 1;
+  bool retryable = 2;
+  // Whether this task failure should count toward the maximum number of times the task is allowed to retry
+  bool count_to_failures = 3;
+  oneof failed_reason {
+    ExecutionError execution_error = 4;
+    FetchPartitionError fetch_partition_error = 5;
+    IOError io_error = 6;
+    ExecutorLost executor_lost = 7;
+    // A successful task's result is lost because its executor was lost
+    ResultLost result_lost = 8;
+    TaskKilled task_killed = 9;
+  }
 }
 
-message CompletedTask {
+message SuccessfulTask {
   string executor_id = 1;
   // TODO tasks are currently always shuffle writes but this will not always be the case
   // so we might want to think about some refactoring of the task definitions
   repeated ShuffleWritePartition partitions = 2;
 }
 
+message ExecutionError {
+}
+
+message FetchPartitionError {
+  string executor_id = 1;
+  uint32 map_stage_id = 2;
+  uint32 map_partition_id = 3;
+}
+
+message IOError {
+}
+
+message ExecutorLost {
+}
+
+message ResultLost {
+}
+
+message TaskKilled {
+}
+
 message ShuffleWritePartition {
   uint64 partition_id = 1;
   string path = 2;
@@ -692,13 +759,20 @@ message ShuffleWritePartition {
 }
 
 message TaskStatus {
-  PartitionId task_id = 1;
+  uint32 task_id = 1;
+  string job_id = 2;
+  uint32 stage_id = 3;
+  uint32 stage_attempt_num = 4;
+  uint32 partition_id = 5;
+  uint64 launch_time = 6;
+  uint64 start_exec_time = 7;
+  uint64 end_exec_time = 8;
   oneof status {
-    RunningTask running = 2;
-    FailedTask failed = 3;
-    CompletedTask completed = 4;
+    RunningTask running = 9;
+    FailedTask failed = 10;
+    SuccessfulTask successful = 11;
   }
-  repeated OperatorMetricsSet metrics = 5;
+  repeated OperatorMetricsSet metrics = 12;
 }
 
 message PollWorkParams {
@@ -709,22 +783,32 @@ message PollWorkParams {
 }
 
 message TaskDefinition {
-  PartitionId task_id = 1;
-  bytes plan = 2;
+  uint32 task_id = 1;
+  uint32 task_attempt_num = 2;
+  string job_id = 3;
+  uint32 stage_id = 4;
+  uint32 stage_attempt_num = 5;
+  uint32 partition_id = 6;
+  bytes plan = 7;
   // Output partition for shuffle writer
-  PhysicalHashRepartition output_partitioning = 3;
-  string session_id = 4;
-  repeated KeyValuePair props = 5;
+  PhysicalHashRepartition output_partitioning = 8;
+  string session_id = 9;
+  uint64 launch_time = 10;
+  repeated KeyValuePair props = 11;
 }
 
 // A set of tasks in the same stage
 message MultiTaskDefinition {
-  PartitionIds task_ids = 1;
-  bytes plan = 2;
+  repeated TaskId task_ids = 1;
+  string job_id = 2;
+  uint32 stage_id = 3;
+  uint32 stage_attempt_num = 4;
+  bytes plan = 5;
   // Output partition for shuffle writer
-  PhysicalHashRepartition output_partitioning = 3;
-  string session_id = 4;
-  repeated KeyValuePair props = 5;
+  PhysicalHashRepartition output_partitioning = 6;
+  string session_id = 7;
+  uint64 launch_time = 8;
+  repeated KeyValuePair props = 9;
 }
 
 message SessionSettings {
@@ -812,7 +896,7 @@ message GetJobStatusParams {
   string job_id = 1;
 }
 
-message CompletedJob {
+message SuccessfulJob {
   repeated PartitionLocation partition_location = 1;
 }
 
@@ -830,7 +914,7 @@ message JobStatus {
     QueuedJob queued = 1;
     RunningJob running = 2;
     FailedJob failed = 3;
-    CompletedJob completed = 4;
+    SuccessfulJob successful = 4;
   }
 }
 
@@ -882,13 +966,20 @@ message LaunchMultiTaskResult {
 }
 
 message CancelTasksParams {
-  repeated PartitionId partition_id = 1;
+  repeated RunningTaskInfo task_infos = 1;
 }
 
 message CancelTasksResult {
   bool cancelled = 1;
 }
 
+message RunningTaskInfo {
+  uint32 task_id = 1;
+  string job_id = 2;
+  uint32 stage_id = 3;
+  uint32 partition_id = 4;;
+}
+
 service SchedulerGrpc {
   // Executors must poll the scheduler for heartbeat and to receive tasks
   rpc PollWork (PollWorkParams) returns (PollWorkResult) {}
diff --git a/ballista/rust/core/src/client.rs b/ballista/rust/core/src/client.rs
index 13276f1b..460254c4 100644
--- a/ballista/rust/core/src/client.rs
+++ b/ballista/rust/core/src/client.rs
@@ -25,8 +25,8 @@ use std::{
     task::{Context, Poll},
 };
 
-use crate::error::{ballista_error, BallistaError, Result};
-use crate::serde::scheduler::Action;
+use crate::error::{BallistaError, Result};
+use crate::serde::scheduler::{Action, PartitionId};
 
 use arrow_flight::utils::flight_data_to_arrow_batch;
 use arrow_flight::Ticket;
@@ -62,7 +62,7 @@ impl BallistaClient {
             create_grpc_client_connection(addr.clone())
                 .await
                 .map_err(|e| {
-                    BallistaError::General(format!(
+                    BallistaError::GrpcConnectionError(format!(
                         "Error connecting to Ballista scheduler or executor at {}: {:?}",
                         addr, e
                     ))
@@ -76,22 +76,32 @@ impl BallistaClient {
     /// Fetch a partition from an executor
     pub async fn fetch_partition(
         &mut self,
-        job_id: &str,
-        stage_id: usize,
-        partition_id: usize,
+        executor_id: &str,
+        partition_id: &PartitionId,
         path: &str,
         host: &str,
         port: u16,
     ) -> Result<SendableRecordBatchStream> {
         let action = Action::FetchPartition {
-            job_id: job_id.to_string(),
-            stage_id,
-            partition_id,
+            job_id: partition_id.job_id.clone(),
+            stage_id: partition_id.stage_id,
+            partition_id: partition_id.partition_id,
             path: path.to_owned(),
-            host: host.to_string(),
+            host: host.to_owned(),
             port,
         };
-        self.execute_action(&action).await
+        self.execute_action(&action)
+            .await
+            .map_err(|error| match error {
+                // map grpc action error to partition fetch error.
+                BallistaError::GrpcActionError(msg) => BallistaError::FetchFailed(
+                    executor_id.to_owned(),
+                    partition_id.stage_id,
+                    partition_id.partition_id,
+                    msg,
+                ),
+                other => other,
+            })
     }
 
     /// Execute an action and retrieve the results
@@ -105,7 +115,7 @@ impl BallistaClient {
 
         serialized_action
             .encode(&mut buf)
-            .map_err(|e| BallistaError::General(format!("{:?}", e)))?;
+            .map_err(|e| BallistaError::GrpcActionError(format!("{:?}", e)))?;
 
         let request = tonic::Request::new(Ticket { ticket: buf });
 
@@ -113,14 +123,14 @@ impl BallistaClient {
             .flight_client
             .do_get(request)
             .await
-            .map_err(|e| BallistaError::General(format!("{:?}", e)))?
+            .map_err(|e| BallistaError::GrpcActionError(format!("{:?}", e)))?
             .into_inner();
 
         // the schema should be the first message returned, else client should error
         match stream
             .message()
             .await
-            .map_err(|e| BallistaError::General(format!("{:?}", e)))?
+            .map_err(|e| BallistaError::GrpcActionError(format!("{:?}", e)))?
         {
             Some(flight_data) => {
                 // convert FlightData to a stream
@@ -129,8 +139,8 @@ impl BallistaClient {
                 // all the remaining stream messages should be dictionary and record batches
                 Ok(Box::pin(FlightDataStream::new(stream, schema)))
             }
-            None => Err(ballista_error(
-                "Did not receive schema batch from flight server",
+            None => Err(BallistaError::GrpcActionError(
+                "Did not receive schema batch from flight server".to_string(),
             )),
         }
     }
diff --git a/ballista/rust/core/src/error.rs b/ballista/rust/core/src/error.rs
index dd4f9dda..53d779e1 100644
--- a/ballista/rust/core/src/error.rs
+++ b/ballista/rust/core/src/error.rs
@@ -23,6 +23,8 @@ use std::{
     io, result,
 };
 
+use crate::serde::protobuf::failed_task::FailedReason;
+use crate::serde::protobuf::{ExecutionError, FailedTask, FetchPartitionError, IoError};
 use datafusion::arrow::error::ArrowError;
 use datafusion::error::DataFusionError;
 use futures::future::Aborted;
@@ -47,7 +49,11 @@ pub enum BallistaError {
     // KubeAPIResponseError(k8s_openapi::ResponseError),
     TonicError(tonic::transport::Error),
     GrpcError(tonic::Status),
+    GrpcConnectionError(String),
     TokioError(tokio::task::JoinError),
+    GrpcActionError(String),
+    // (executor_id, map_stage_id, map_partition_id, message)
+    FetchFailed(String, usize, usize, String),
     Cancelled,
 }
 
@@ -70,7 +76,19 @@ impl From<String> for BallistaError {
 
 impl From<ArrowError> for BallistaError {
     fn from(e: ArrowError) -> Self {
-        BallistaError::ArrowError(e)
+        match e {
+            ArrowError::ExternalError(e)
+                if e.downcast_ref::<BallistaError>().is_some() =>
+            {
+                *e.downcast::<BallistaError>().unwrap()
+            }
+            ArrowError::ExternalError(e)
+                if e.downcast_ref::<DataFusionError>().is_some() =>
+            {
+                BallistaError::DataFusionError(*e.downcast::<DataFusionError>().unwrap())
+            }
+            other => BallistaError::ArrowError(other),
+        }
     }
 }
 
@@ -182,13 +200,78 @@ impl Display for BallistaError {
             // }
             BallistaError::TonicError(desc) => write!(f, "Tonic error: {}", desc),
             BallistaError::GrpcError(desc) => write!(f, "Grpc error: {}", desc),
+            BallistaError::GrpcConnectionError(desc) => {
+                write!(f, "Grpc connection error: {}", desc)
+            }
             BallistaError::Internal(desc) => {
                 write!(f, "Internal Ballista error: {}", desc)
             }
             BallistaError::TokioError(desc) => write!(f, "Tokio join error: {}", desc),
+            BallistaError::GrpcActionError(desc) => {
+                write!(f, "Grpc Execute Action error: {}", desc)
+            }
+            BallistaError::FetchFailed(executor_id, map_stage, map_partition, desc) => {
+                write!(
+                    f,
+                    "Shuffle fetch partition error from Executor {}, map_stage {}, \
+                map_partition {}, error desc: {}",
+                    executor_id, map_stage, map_partition, desc
+                )
+            }
             BallistaError::Cancelled => write!(f, "Task cancelled"),
         }
     }
 }
 
+impl From<BallistaError> for FailedTask {
+    fn from(e: BallistaError) -> Self {
+        match e {
+            BallistaError::FetchFailed(
+                executor_id,
+                map_stage_id,
+                map_partition_id,
+                desc,
+            ) => {
+                FailedTask {
+                    error: desc,
+                    // fetch partition error is considered to be non-retryable
+                    retryable: false,
+                    count_to_failures: false,
+                    failed_reason: Some(FailedReason::FetchPartitionError(
+                        FetchPartitionError {
+                            executor_id,
+                            map_stage_id: map_stage_id as u32,
+                            map_partition_id: map_partition_id as u32,
+                        },
+                    )),
+                }
+            }
+            BallistaError::IoError(io) => {
+                FailedTask {
+                    error: format!("Task failed due to Ballista IO error: {:?}", io),
+                    // IO error is considered to be temporary and retryable
+                    retryable: true,
+                    count_to_failures: true,
+                    failed_reason: Some(FailedReason::IoError(IoError {})),
+                }
+            }
+            BallistaError::DataFusionError(DataFusionError::IoError(io)) => {
+                FailedTask {
+                    error: format!("Task failed due to DataFusion IO error: {:?}", io),
+                    // IO error is considered to be temporary and retryable
+                    retryable: true,
+                    count_to_failures: true,
+                    failed_reason: Some(FailedReason::IoError(IoError {})),
+                }
+            }
+            other => FailedTask {
+                error: format!("Task failed due to runtime execution error: {:?}", other),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(FailedReason::ExecutionError(ExecutionError {})),
+            },
+        }
+    }
+}
+
 impl Error for BallistaError {}
diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs
index 67393c20..bf5dc0cf 100644
--- a/ballista/rust/core/src/execution_plans/distributed_query.rs
+++ b/ballista/rust/core/src/execution_plans/distributed_query.rs
@@ -294,8 +294,8 @@ async fn execute_query(
                 error!("{}", msg);
                 break Err(DataFusionError::Execution(msg));
             }
-            Some(job_status::Status::Completed(completed)) => {
-                let streams = completed.partition_location.into_iter().map(|p| {
+            Some(job_status::Status::Successful(successful)) => {
+                let streams = successful.partition_location.into_iter().map(|p| {
                     let f = fetch_partition(p)
                         .map_err(|e| ArrowError::ExternalError(Box::new(e)));
 
@@ -324,13 +324,12 @@ async fn fetch_partition(
         .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
     ballista_client
         .fetch_partition(
-            &partition_id.job_id,
-            partition_id.stage_id as usize,
-            partition_id.partition_id as usize,
+            &metadata.id,
+            &partition_id.into(),
             &location.path,
             host,
             port,
         )
         .await
-        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))
+        .map_err(|e| DataFusionError::External(Box::new(e)))
 }
diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
index d0d9a28b..e13e578e 100644
--- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs
+++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs
@@ -15,15 +15,17 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use async_trait::async_trait;
 use std::any::Any;
 use std::collections::HashMap;
+use std::result;
 use std::sync::Arc;
 
-#[cfg(not(test))]
 use crate::client::BallistaClient;
 use crate::serde::scheduler::{PartitionLocation, PartitionStats};
 
 use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::ArrowError;
 
 use datafusion::error::{DataFusionError, Result};
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -33,6 +35,7 @@ use datafusion::physical_plan::{
 };
 use futures::{Stream, StreamExt, TryStreamExt};
 
+use crate::error::BallistaError;
 use datafusion::execution::context::TaskContext;
 use datafusion::physical_plan::common::AbortOnDropMany;
 use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
@@ -131,8 +134,10 @@ impl ExecutionPlan for ShuffleReaderExec {
             .collect();
         // Shuffle partitions for evenly send fetching partition requests to avoid hot executors within multiple tasks
         partition_locations.shuffle(&mut thread_rng());
+
+        let partition_reader = FlightPartitionReader {};
         let response_receiver =
-            send_fetch_partitions(partition_locations, max_request_num);
+            send_fetch_partitions(partition_locations, max_request_num, partition_reader);
 
         let result = RecordBatchStreamAdapter::new(
             Arc::new(self.schema.as_ref().clone()),
@@ -192,7 +197,7 @@ fn stats_for_partitions(
 
 /// Adapter for a tokio ReceiverStream that implements the SendableRecordBatchStream interface
 struct AbortableReceiverStream {
-    inner: ReceiverStream<Result<SendableRecordBatchStream>>,
+    inner: ReceiverStream<result::Result<SendableRecordBatchStream, BallistaError>>,
 
     #[allow(dead_code)]
     drop_helper: AbortOnDropMany<()>,
@@ -201,7 +206,9 @@ struct AbortableReceiverStream {
 impl AbortableReceiverStream {
     /// Construct a new SendableRecordBatchReceiverStream which will send batches of the specified schema from inner
     pub fn create(
-        rx: tokio::sync::mpsc::Receiver<Result<SendableRecordBatchStream>>,
+        rx: tokio::sync::mpsc::Receiver<
+            result::Result<SendableRecordBatchStream, BallistaError>,
+        >,
         join_handles: Vec<JoinHandle<()>>,
     ) -> AbortableReceiverStream {
         let inner = ReceiverStream::new(rx);
@@ -213,19 +220,22 @@ impl AbortableReceiverStream {
 }
 
 impl Stream for AbortableReceiverStream {
-    type Item = Result<SendableRecordBatchStream>;
+    type Item = result::Result<SendableRecordBatchStream, ArrowError>;
 
     fn poll_next(
         mut self: std::pin::Pin<&mut Self>,
         cx: &mut std::task::Context<'_>,
     ) -> std::task::Poll<Option<Self::Item>> {
-        self.inner.poll_next_unpin(cx)
+        self.inner
+            .poll_next_unpin(cx)
+            .map_err(|e| ArrowError::ExternalError(Box::new(e)))
     }
 }
 
-fn send_fetch_partitions(
+fn send_fetch_partitions<R: PartitionReader + 'static>(
     partition_locations: Vec<PartitionLocation>,
     max_request_num: usize,
+    partition_reader: R,
 ) -> AbortableReceiverStream {
     let (response_sender, response_receiver) = mpsc::channel(max_request_num);
     let semaphore = Arc::new(Semaphore::new(max_request_num));
@@ -233,10 +243,11 @@ fn send_fetch_partitions(
     for p in partition_locations.into_iter() {
         let semaphore = semaphore.clone();
         let response_sender = response_sender.clone();
+        let partition_reader_clone = partition_reader.clone();
         let join_handle = tokio::spawn(async move {
             // Block if exceeds max request number
             let permit = semaphore.acquire_owned().await.unwrap();
-            let r = fetch_partition(&p).await;
+            let r = partition_reader_clone.fetch_partition(&p).await;
             // Block if the channel buffer is full
             if let Err(e) = response_sender.send(r).await {
                 error!("Fail to send response event to the channel due to {}", e);
@@ -250,48 +261,72 @@ fn send_fetch_partitions(
     AbortableReceiverStream::create(response_receiver, join_handles)
 }
 
-#[cfg(not(test))]
-async fn fetch_partition(
-    location: &PartitionLocation,
-) -> Result<SendableRecordBatchStream> {
-    let metadata = &location.executor_meta;
-    let partition_id = &location.partition_id;
-    // TODO for shuffle client connections, we should avoid creating new connections again and again.
-    // And we should also avoid to keep alive too many connections for long time.
-    let host = metadata.host.as_str();
-    let port = metadata.port as u16;
-    let mut ballista_client = BallistaClient::try_new(host, port)
-        .await
-        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
-    ballista_client
-        .fetch_partition(
-            &partition_id.job_id,
-            partition_id.stage_id as usize,
-            partition_id.partition_id as usize,
-            &location.path,
-            host,
-            port,
-        )
-        .await
-        .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))
+/// Partition reader trait; different partition readers can provide different implementations.
+#[async_trait]
+trait PartitionReader: Send + Sync + Clone {
+    // Read partition data from PartitionLocation
+    async fn fetch_partition(
+        &self,
+        location: &PartitionLocation,
+    ) -> result::Result<SendableRecordBatchStream, BallistaError>;
 }
 
-#[cfg(test)]
-async fn fetch_partition(
-    location: &PartitionLocation,
-) -> Result<SendableRecordBatchStream> {
-    tests::fetch_test_partition(location)
+#[derive(Clone)]
+struct FlightPartitionReader {}
+
+#[async_trait]
+impl PartitionReader for FlightPartitionReader {
+    async fn fetch_partition(
+        &self,
+        location: &PartitionLocation,
+    ) -> result::Result<SendableRecordBatchStream, BallistaError> {
+        let metadata = &location.executor_meta;
+        let partition_id = &location.partition_id;
+        // TODO for shuffle client connections, we should avoid creating new connections again and again.
+        // And we should also avoid to keep alive too many connections for long time.
+        let host = metadata.host.as_str();
+        let port = metadata.port as u16;
+        let mut ballista_client =
+            BallistaClient::try_new(host, port)
+                .await
+                .map_err(|error| match error {
+                    // map grpc connection error to partition fetch error.
+                    BallistaError::GrpcConnectionError(msg) => {
+                        BallistaError::FetchFailed(
+                            metadata.id.clone(),
+                            partition_id.stage_id,
+                            partition_id.partition_id,
+                            msg,
+                        )
+                    }
+                    other => other,
+                })?;
+
+        ballista_client
+            .fetch_partition(&metadata.id, partition_id, &location.path, host, port)
+            .await
+    }
 }
 
+#[allow(dead_code)]
+// TODO
+struct LocalPartitionReader {}
+
+#[allow(dead_code)]
+// TODO
+struct ObjectStorePartitionReader {}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::serde::scheduler::{ExecutorMetadata, ExecutorSpecification, PartitionId};
+    use crate::utils;
     use datafusion::arrow::array::Int32Array;
     use datafusion::arrow::datatypes::{DataType, Field, Schema};
     use datafusion::arrow::record_batch::RecordBatch;
     use datafusion::physical_plan::common;
     use datafusion::physical_plan::stream::RecordBatchReceiverStream;
+    use datafusion::prelude::SessionContext;
 
     #[tokio::test]
     async fn test_stats_for_partitions_empty() {
@@ -361,6 +396,55 @@ mod tests {
         assert_eq!(result, exptected);
     }
 
+    #[tokio::test]
+    async fn test_fetch_partitions_error_mapping() -> Result<()> {
+        let session_ctx = SessionContext::new();
+        let task_ctx = session_ctx.task_ctx();
+        let schema = Schema::new(vec![
+            Field::new("a", DataType::Int32, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int32, false),
+        ]);
+
+        let job_id = "test_job_1";
+        let mut partitions: Vec<PartitionLocation> = vec![];
+        for partition_id in 0..4 {
+            partitions.push(PartitionLocation {
+                map_partition_id: 0,
+                partition_id: PartitionId {
+                    job_id: job_id.to_string(),
+                    stage_id: 2,
+                    partition_id,
+                },
+                executor_meta: ExecutorMetadata {
+                    id: "executor_1".to_string(),
+                    host: "executor_1".to_string(),
+                    port: 7070,
+                    grpc_port: 8080,
+                    specification: ExecutorSpecification { task_slots: 1 },
+                },
+                partition_stats: Default::default(),
+                path: "test_path".to_string(),
+            })
+        }
+
+        let shuffle_reader_exec =
+            ShuffleReaderExec::try_new(vec![partitions], Arc::new(schema))?;
+        let mut stream = shuffle_reader_exec.execute(0, task_ctx)?;
+        let batches = utils::collect_stream(&mut stream).await;
+
+        assert!(batches.is_err());
+
+        // BallistaError::FetchFailed -> ArrowError::ExternalError -> BallistaError::FetchFailed
+        let ballista_error = batches.unwrap_err();
+        assert!(matches!(
+            ballista_error,
+            BallistaError::FetchFailed(_, _, _, _)
+        ));
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_send_fetch_partitions_1() {
         test_send_fetch_partitions(1, 10).await;
@@ -374,8 +458,9 @@ mod tests {
     async fn test_send_fetch_partitions(max_request_num: usize, partition_num: usize) {
         let schema = get_test_partition_schema();
         let partition_locations = get_test_partition_locations(partition_num);
+        let partition_reader = MockPartitionReader {};
         let response_receiver =
-            send_fetch_partitions(partition_locations, max_request_num);
+            send_fetch_partitions(partition_locations, max_request_num, partition_reader);
 
         let stream = RecordBatchStreamAdapter::new(
             Arc::new(schema),
@@ -390,6 +475,7 @@ mod tests {
         (0..n)
             .into_iter()
             .map(|partition_id| PartitionLocation {
+                map_partition_id: 0,
                 partition_id: PartitionId {
                     job_id: "job".to_string(),
                     stage_id: 1,
@@ -408,31 +494,40 @@ mod tests {
             .collect()
     }
 
-    pub(crate) fn fetch_test_partition(
-        location: &PartitionLocation,
-    ) -> Result<SendableRecordBatchStream> {
-        let id_array = Int32Array::from(vec![location.partition_id.partition_id as i32]);
-        let schema = Arc::new(get_test_partition_schema());
-
-        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(id_array)])?;
-
-        let (tx, rx) = tokio::sync::mpsc::channel(2);
-
-        // task simply sends data in order but in a separate
-        // thread (to ensure the batches are not available without the
-        // DelayedStream yielding).
-        let join_handle = tokio::task::spawn(async move {
-            println!("Sending batch via delayed stream");
-            if let Err(e) = tx.send(Ok(batch)).await {
-                println!("ERROR batch via delayed stream: {}", e);
-            }
-        });
-
-        // returned stream simply reads off the rx stream
-        Ok(RecordBatchReceiverStream::create(&schema, rx, join_handle))
-    }
-
     fn get_test_partition_schema() -> Schema {
         Schema::new(vec![Field::new("id", DataType::Int32, false)])
     }
+
+    #[derive(Clone)]
+    struct MockPartitionReader {}
+
+    #[async_trait]
+    impl PartitionReader for MockPartitionReader {
+        async fn fetch_partition(
+            &self,
+            location: &PartitionLocation,
+        ) -> result::Result<SendableRecordBatchStream, BallistaError> {
+            let id_array =
+                Int32Array::from(vec![location.partition_id.partition_id as i32]);
+            let schema =
+                Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
+
+            let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(id_array)])?;
+
+            let (tx, rx) = tokio::sync::mpsc::channel(2);
+
+            // task simply sends data in order but in a separate
+            // thread (to ensure the batches are not available without the
+            // DelayedStream yielding).
+            let join_handle = tokio::task::spawn(async move {
+                println!("Sending batch via delayed stream");
+                if let Err(e) = tx.send(Ok(batch)).await {
+                    println!("ERROR batch via delayed stream: {}", e);
+                }
+            });
+
+            // returned stream simply reads off the rx stream
+            Ok(RecordBatchReceiverStream::create(&schema, rx, join_handle))
+        }
+    }
 }
diff --git a/ballista/rust/core/src/serde/scheduler/from_proto.rs b/ballista/rust/core/src/serde/scheduler/from_proto.rs
index cfe0cbbf..a9e0e66b 100644
--- a/ballista/rust/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/rust/core/src/serde/scheduler/from_proto.rs
@@ -31,7 +31,7 @@ use crate::serde::protobuf::action::ActionType;
 use crate::serde::protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionIds, PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, TaskDefinition,
 };
 
 impl TryInto<Action> for protobuf::Action {
@@ -89,6 +89,7 @@ impl TryInto<PartitionLocation> for protobuf::PartitionLocation {
 
     fn try_into(self) -> Result<PartitionLocation, Self::Error> {
         Ok(PartitionLocation {
+            map_partition_id: self.map_partition_id as usize,
             partition_id: self
                 .partition_id
                 .ok_or_else(|| {
@@ -266,21 +267,6 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
     }
 }
 
-#[allow(clippy::from_over_into)]
-impl Into<PartitionIds> for protobuf::PartitionIds {
-    fn into(self) -> PartitionIds {
-        PartitionIds {
-            job_id: self.job_id.clone(),
-            stage_id: self.stage_id as usize,
-            partition_ids: self
-                .partition_ids
-                .into_iter()
-                .map(|partition_id| partition_id as usize)
-                .collect(),
-        }
-    }
-}
-
 impl TryInto<TaskDefinition> for protobuf::TaskDefinition {
     type Error = BallistaError;
 
@@ -290,18 +276,17 @@ impl TryInto<TaskDefinition> for protobuf::TaskDefinition {
             props.insert(kv_pair.key, kv_pair.value);
         }
 
-        let task_id = self
-            .task_id
-            .ok_or_else(|| {
-                BallistaError::General("No task id in the TaskDefinition".to_owned())
-            })?
-            .into();
-
         Ok(TaskDefinition {
-            task_id,
+            task_id: self.task_id as usize,
+            task_attempt_num: self.task_attempt_num as usize,
+            job_id: self.job_id,
+            stage_id: self.stage_id as usize,
+            stage_attempt_num: self.stage_attempt_num as usize,
+            partition_id: self.partition_id as usize,
             plan: self.plan,
             output_partitioning: self.output_partitioning,
             session_id: self.session_id,
+            launch_time: self.launch_time,
             props,
         })
     }
@@ -319,27 +304,25 @@ impl TryInto<Vec<TaskDefinition>> for protobuf::MultiTaskDefinition {
         let plan = self.plan;
         let output_partitioning = self.output_partitioning;
         let session_id = self.session_id;
-        let task_ids: PartitionIds = self
-            .task_ids
-            .ok_or_else(|| {
-                BallistaError::General(
-                    "No task ids in the MultiTaskDefinition".to_owned(),
-                )
-            })?
-            .into();
+        let job_id = self.job_id;
+        let stage_id = self.stage_id as usize;
+        let stage_attempt_num = self.stage_attempt_num as usize;
+        let launch_time = self.launch_time;
+        let task_ids = self.task_ids;
 
         Ok(task_ids
-            .partition_ids
             .iter()
-            .map(|partition_id| TaskDefinition {
-                task_id: PartitionId::new(
-                    &task_ids.job_id,
-                    task_ids.stage_id,
-                    *partition_id,
-                ),
+            .map(|task_id| TaskDefinition {
+                task_id: task_id.task_id as usize,
+                task_attempt_num: task_id.task_attempt_num as usize,
+                job_id: job_id.clone(),
+                stage_id,
+                stage_attempt_num,
+                partition_id: task_id.partition_id as usize,
                 plan: plan.clone(),
                 output_partitioning: output_partitioning.clone(),
                 session_id: session_id.clone(),
+                launch_time,
                 props: props.clone(),
             })
             .collect())
diff --git a/ballista/rust/core/src/serde/scheduler/mod.rs b/ballista/rust/core/src/serde/scheduler/mod.rs
index 7a710f49..370dd334 100644
--- a/ballista/rust/core/src/serde/scheduler/mod.rs
+++ b/ballista/rust/core/src/serde/scheduler/mod.rs
@@ -65,6 +65,7 @@ impl PartitionId {
 
 #[derive(Debug, Clone)]
 pub struct PartitionLocation {
+    pub map_partition_id: usize,
     pub partition_id: PartitionId,
     pub executor_meta: ExecutorMetadata,
     pub partition_stats: PartitionStats,
@@ -271,19 +272,17 @@ impl ExecutePartitionResult {
     }
 }
 
-/// Unique identifier for the output partitions of a set of tasks.
-#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
-pub struct PartitionIds {
-    pub job_id: String,
-    pub stage_id: usize,
-    pub partition_ids: Vec<usize>,
-}
-
 #[derive(Debug, Clone)]
 pub struct TaskDefinition {
-    pub task_id: PartitionId,
+    pub task_id: usize,
+    pub task_attempt_num: usize,
+    pub job_id: String,
+    pub stage_id: usize,
+    pub stage_attempt_num: usize,
+    pub partition_id: usize,
     pub plan: Vec<u8>,
     pub output_partitioning: Option<PhysicalHashRepartition>,
     pub session_id: String,
+    pub launch_time: u64,
     pub props: HashMap<String, String>,
 }
diff --git a/ballista/rust/core/src/serde/scheduler/to_proto.rs b/ballista/rust/core/src/serde/scheduler/to_proto.rs
index 0c43b533..d453548e 100644
--- a/ballista/rust/core/src/serde/scheduler/to_proto.rs
+++ b/ballista/rust/core/src/serde/scheduler/to_proto.rs
@@ -27,7 +27,7 @@ use crate::serde::protobuf::{
 };
 use crate::serde::scheduler::{
     Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
-    PartitionIds, PartitionLocation, PartitionStats, TaskDefinition,
+    PartitionLocation, PartitionStats, TaskDefinition,
 };
 use datafusion::physical_plan::Partitioning;
 
@@ -74,6 +74,7 @@ impl TryInto<protobuf::PartitionLocation> for PartitionLocation {
 
     fn try_into(self) -> Result<protobuf::PartitionLocation, Self::Error> {
         Ok(protobuf::PartitionLocation {
+            map_partition_id: self.map_partition_id as u32,
             partition_id: Some(self.partition_id.into()),
             executor_meta: Some(self.executor_meta.into()),
             partition_stats: Some(self.partition_stats.into()),
@@ -241,21 +242,6 @@ impl Into<protobuf::ExecutorData> for ExecutorData {
     }
 }
 
-#[allow(clippy::from_over_into)]
-impl Into<protobuf::PartitionIds> for PartitionIds {
-    fn into(self) -> protobuf::PartitionIds {
-        protobuf::PartitionIds {
-            job_id: self.job_id,
-            stage_id: self.stage_id as u32,
-            partition_ids: self
-                .partition_ids
-                .into_iter()
-                .map(|partition_id| partition_id as u32)
-                .collect(),
-        }
-    }
-}
-
 #[allow(clippy::from_over_into)]
 impl Into<protobuf::TaskDefinition> for TaskDefinition {
     fn into(self) -> protobuf::TaskDefinition {
@@ -269,10 +255,16 @@ impl Into<protobuf::TaskDefinition> for TaskDefinition {
             .collect::<Vec<_>>();
 
         protobuf::TaskDefinition {
-            task_id: Some(self.task_id.into()),
+            task_id: self.task_id as u32,
+            task_attempt_num: self.task_attempt_num as u32,
+            job_id: self.job_id,
+            stage_id: self.stage_id as u32,
+            stage_attempt_num: self.stage_attempt_num as u32,
+            partition_id: self.partition_id as u32,
             plan: self.plan,
             output_partitioning: self.output_partitioning,
             session_id: self.session_id,
+            launch_time: self.launch_time,
             props,
         }
     }
diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs
index 20f8b927..21626c55 100644
--- a/ballista/rust/core/src/utils.rs
+++ b/ballista/rust/core/src/utils.rs
@@ -49,6 +49,7 @@ use datafusion_proto::logical_plan::{
     AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
 };
 use futures::StreamExt;
+use log::error;
 #[cfg(feature = "s3")]
 use object_store::aws::AmazonS3Builder;
 use object_store::ObjectStore;
@@ -122,10 +123,8 @@ pub async fn write_stream_to_disk(
     disk_write_metric: &metrics::Time,
 ) -> Result<PartitionStats> {
     let file = File::create(&path).map_err(|e| {
-        BallistaError::General(format!(
-            "Failed to create partition file at {}: {:?}",
-            path, e
-        ))
+        error!("Failed to create partition file at {}: {:?}", path, e);
+        BallistaError::IoError(e)
     })?;
 
     let mut num_rows = 0;
diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs
index f7e029c2..3afe9b91 100644
--- a/ballista/rust/executor/src/execution_loop.rs
+++ b/ballista/rust/executor/src/execution_loop.rs
@@ -22,11 +22,11 @@ use ballista_core::serde::protobuf::{
     TaskDefinition, TaskStatus,
 };
 
-use crate::as_task_status;
 use crate::executor::Executor;
+use crate::{as_task_status, TaskExecutionTimes};
 use ballista_core::error::BallistaError;
 use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
-use ballista_core::serde::scheduler::ExecutorSpecification;
+use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId};
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
 use ballista_core::utils::collect_plan_metrics;
 use datafusion::execution::context::TaskContext;
@@ -40,6 +40,7 @@ use std::error::Error;
 use std::ops::Deref;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::mpsc::{Receiver, Sender, TryRecvError};
+use std::time::{SystemTime, UNIX_EPOCH};
 use std::{sync::Arc, time::Duration};
 use tonic::transport::Channel;
 
@@ -136,16 +137,24 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
     task: TaskDefinition,
     codec: &BallistaCodec<T, U>,
 ) -> Result<(), BallistaError> {
-    let task_id = task.task_id.unwrap();
-    let task_id_log = format!(
-        "{}/{}/{}",
-        task_id.job_id, task_id.stage_id, task_id.partition_id
+    let task_id = task.task_id;
+    let task_attempt_num = task.task_attempt_num;
+    let job_id = task.job_id;
+    let stage_id = task.stage_id;
+    let stage_attempt_num = task.stage_attempt_num;
+    let task_launch_time = task.launch_time;
+    let partition_id = task.partition_id;
+    let start_exec_time = SystemTime::now()
+        .duration_since(UNIX_EPOCH)
+        .unwrap()
+        .as_millis() as u64;
+    let task_identity = format!(
+        "TID {} {}/{}.{}/{}.{}",
+        task_id, job_id, stage_id, stage_attempt_num, partition_id, task_attempt_num
     );
-    info!("Received task {}", task_id_log);
+    info!("Received task {}", task_identity);
     available_tasks_slots.fetch_sub(1, Ordering::SeqCst);
 
-    let runtime = executor.runtime.clone();
-    let session_id = task.session_id;
     let mut task_props = HashMap::new();
     for kv_pair in task.props {
         task_props.insert(kv_pair.key, kv_pair.value);
@@ -160,8 +169,10 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
     for agg_func in executor.aggregate_functions.clone() {
         task_aggregate_functions.insert(agg_func.0, agg_func.1);
     }
+    let runtime = executor.runtime.clone();
+    let session_id = task.session_id.clone();
     let task_context = Arc::new(TaskContext::new(
-        task_id_log.clone(),
+        task_identity.clone(),
         session_id,
         task_props,
         task_scalar_functions,
@@ -184,17 +195,19 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
         plan.schema().as_ref(),
     )?;
 
-    let shuffle_writer_plan = executor.new_shuffle_writer(
-        task_id.job_id.clone(),
-        task_id.stage_id as usize,
-        plan,
-    )?;
+    let shuffle_writer_plan =
+        executor.new_shuffle_writer(job_id.clone(), stage_id as usize, plan)?;
     tokio::spawn(async move {
         use std::panic::AssertUnwindSafe;
+        let part = PartitionId {
+            job_id: job_id.clone(),
+            stage_id: stage_id as usize,
+            partition_id: partition_id as usize,
+        };
+
         let execution_result = match AssertUnwindSafe(executor.execute_shuffle_write(
-            task_id.job_id.clone(),
-            task_id.stage_id as usize,
-            task_id.partition_id as usize,
+            task_id as usize,
+            part.clone(),
             shuffle_writer_plan.clone(),
             task_context,
             shuffle_output_partitioning,
@@ -210,7 +223,7 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
             }
         };
 
-        info!("Done with task {}", task_id_log);
+        info!("Done with task {}", task_identity);
         debug!("Statistics: {:?}", execution_result);
         available_tasks_slots.fetch_add(1, Ordering::SeqCst);
 
@@ -221,11 +234,25 @@ async fn run_received_tasks<T: 'static + AsLogicalPlan, U: 'static + AsExecution
             .collect::<Result<Vec<_>, BallistaError>>()
             .ok();
 
+        let end_exec_time = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_millis() as u64;
+
+        let task_execution_times = TaskExecutionTimes {
+            launch_time: task_launch_time,
+            start_exec_time,
+            end_exec_time,
+        };
+
         let _ = task_status_sender.send(as_task_status(
             execution_result,
             executor.metadata.id.clone(),
-            task_id,
+            task_id as usize,
+            stage_attempt_num as usize,
+            part,
             operator_metrics,
+            task_execution_times,
         ));
     });
 
diff --git a/ballista/rust/executor/src/executor.rs b/ballista/rust/executor/src/executor.rs
index ff1cde80..47f5fe33 100644
--- a/ballista/rust/executor/src/executor.rs
+++ b/ballista/rust/executor/src/executor.rs
@@ -37,7 +37,7 @@ use futures::future::AbortHandle;
 use ballista_core::serde::scheduler::PartitionId;
 use tokio::sync::Mutex;
 
-type AbortHandles = Arc<Mutex<HashMap<PartitionId, AbortHandle>>>;
+type AbortHandles = Arc<Mutex<HashMap<(usize, PartitionId), AbortHandle>>>;
 
 /// Ballista executor
 #[derive(Clone)]
@@ -96,39 +96,34 @@ impl Executor {
     /// and statistics.
     pub async fn execute_shuffle_write(
         &self,
-        job_id: String,
-        stage_id: usize,
-        part: usize,
+        task_id: usize,
+        partition: PartitionId,
         shuffle_writer: Arc<ShuffleWriterExec>,
         task_ctx: Arc<TaskContext>,
         _shuffle_output_partitioning: Option<Partitioning>,
     ) -> Result<Vec<protobuf::ShuffleWritePartition>, BallistaError> {
         let (task, abort_handle) = futures::future::abortable(
-            shuffle_writer.execute_shuffle_write(part, task_ctx),
+            shuffle_writer.execute_shuffle_write(partition.partition_id, task_ctx),
         );
 
         {
             let mut abort_handles = self.abort_handles.lock().await;
-            abort_handles.insert(
-                PartitionId {
-                    job_id: job_id.clone(),
-                    stage_id,
-                    partition_id: part,
-                },
-                abort_handle,
-            );
+            abort_handles.insert((task_id, partition.clone()), abort_handle);
         }
 
         let partitions = task.await??;
 
-        self.abort_handles.lock().await.remove(&PartitionId {
-            job_id: job_id.clone(),
-            stage_id,
-            partition_id: part,
-        });
+        self.abort_handles
+            .lock()
+            .await
+            .remove(&(task_id, partition.clone()));
 
-        self.metrics_collector
-            .record_stage(&job_id, stage_id, part, shuffle_writer);
+        self.metrics_collector.record_stage(
+            &partition.job_id,
+            partition.stage_id,
+            partition.partition_id,
+            shuffle_writer,
+        );
 
         Ok(partitions)
     }
@@ -162,15 +157,19 @@ impl Executor {
 
     pub async fn cancel_task(
         &self,
+        task_id: usize,
         job_id: String,
         stage_id: usize,
         partition_id: usize,
     ) -> Result<bool, BallistaError> {
-        if let Some(handle) = self.abort_handles.lock().await.remove(&PartitionId {
-            job_id,
-            stage_id,
-            partition_id,
-        }) {
+        if let Some(handle) = self.abort_handles.lock().await.remove(&(
+            task_id,
+            PartitionId {
+                job_id,
+                stage_id,
+                partition_id,
+            },
+        )) {
             handle.abort();
             Ok(true)
         } else {
@@ -194,6 +193,7 @@ mod test {
     use ballista_core::serde::protobuf::ExecutorRegistration;
     use datafusion::execution::context::TaskContext;
 
+    use ballista_core::serde::scheduler::PartitionId;
     use datafusion::physical_expr::PhysicalSortExpr;
     use datafusion::physical_plan::{
         ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream,
@@ -314,11 +314,15 @@ mod test {
         // Spawn our non-terminating task on a separate fiber.
         let executor_clone = executor.clone();
         tokio::task::spawn(async move {
+            let part = PartitionId {
+                job_id: "job-id".to_owned(),
+                stage_id: 1,
+                partition_id: 0,
+            };
             let task_result = executor_clone
                 .execute_shuffle_write(
-                    "job-id".to_owned(),
                     1,
-                    0,
+                    part,
                     Arc::new(shuffle_write),
                     ctx.task_ctx(),
                     None,
@@ -331,7 +335,7 @@ mod test {
         // poll until that happens.
         for _ in 0..20 {
             if executor
-                .cancel_task("job-id".to_owned(), 1, 0)
+                .cancel_task(1, "job-id".to_owned(), 1, 0)
                 .await
                 .expect("cancelling task")
             {
diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs
index 4544d8c2..2bc84f72 100644
--- a/ballista/rust/executor/src/executor_server.rs
+++ b/ballista/rust/executor/src/executor_server.rs
@@ -39,6 +39,7 @@ use ballista_core::serde::protobuf::{
     LaunchMultiTaskResult, LaunchTaskParams, LaunchTaskResult, RegisterExecutorParams,
     StopExecutorParams, StopExecutorResult, TaskStatus, UpdateTaskStatusParams,
 };
+use ballista_core::serde::scheduler::PartitionId;
 use ballista_core::serde::scheduler::TaskDefinition;
 use ballista_core::serde::{AsExecutionPlan, BallistaCodec};
 use ballista_core::utils::{
@@ -50,10 +51,10 @@ use datafusion_proto::logical_plan::AsLogicalPlan;
 use tokio::sync::mpsc::error::TryRecvError;
 use tokio::task::JoinHandle;
 
-use crate::as_task_status;
 use crate::cpu_bound_executor::DedicatedExecutor;
 use crate::executor::Executor;
 use crate::shutdown::ShutdownNotifier;
+use crate::{as_task_status, TaskExecutionTimes};
 
 type ServerHandle = JoinHandle<Result<(), BallistaError>>;
 type SchedulerClients = Arc<RwLock<HashMap<String, SchedulerGrpcClient<Channel>>>>;
@@ -283,19 +284,15 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
 
     async fn run_task(
         &self,
+        task_identity: String,
         curator_task: CuratorTaskDefinition,
     ) -> Result<(), BallistaError> {
-        let scheduler_id = curator_task.scheduler_id;
+        let start_exec_time = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_millis() as u64;
+        info!("Start to run task {}", task_identity);
         let task = curator_task.task;
-        let task_id = task.task_id;
-        let task_id_log = format!(
-            "{}/{}/{}",
-            task_id.job_id, task_id.stage_id, task_id.partition_id
-        );
-        info!("Start to run task {}", task_id_log);
-
-        let runtime = self.executor.runtime.clone();
-        let session_id = task.session_id;
         let task_props = task.props;
 
         let mut task_scalar_functions = HashMap::new();
@@ -307,8 +304,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
         for agg_func in self.executor.aggregate_functions.clone() {
             task_aggregate_functions.insert(agg_func.0, agg_func.1);
         }
+
+        let session_id = task.session_id;
+        let runtime = self.executor.runtime.clone();
         let task_context = Arc::new(TaskContext::new(
-            task_id_log.clone(),
+            task_identity.clone(),
             session_id,
             task_props,
             task_scalar_functions,
@@ -333,24 +333,32 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             plan.schema().as_ref(),
         )?;
 
-        let shuffle_writer_plan = self.executor.new_shuffle_writer(
-            task_id.job_id.clone(),
-            task_id.stage_id as usize,
-            plan,
-        )?;
+        let task_id = task.task_id;
+        let job_id = task.job_id;
+        let stage_id = task.stage_id;
+        let stage_attempt_num = task.stage_attempt_num;
+        let partition_id = task.partition_id;
+        let shuffle_writer_plan =
+            self.executor
+                .new_shuffle_writer(job_id.clone(), stage_id as usize, plan)?;
+
+        let part = PartitionId {
+            job_id: job_id.clone(),
+            stage_id: stage_id as usize,
+            partition_id: partition_id as usize,
+        };
 
         let execution_result = self
             .executor
             .execute_shuffle_write(
-                task_id.job_id.clone(),
-                task_id.stage_id as usize,
-                task_id.partition_id as usize,
+                task_id as usize,
+                part.clone(),
                 shuffle_writer_plan.clone(),
                 task_context,
                 shuffle_output_partitioning,
             )
             .await;
-        info!("Done with task {}", task_id_log);
+        info!("Done with task {}", task_identity);
         debug!("Statistics: {:?}", execution_result);
 
         let plan_metrics = collect_plan_metrics(shuffle_writer_plan.as_ref());
@@ -359,13 +367,28 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
             .map(|m| m.try_into())
             .collect::<Result<Vec<_>, BallistaError>>()?;
         let executor_id = &self.executor.metadata.id;
+
+        let end_exec_time = SystemTime::now()
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_millis() as u64;
+        let task_execution_times = TaskExecutionTimes {
+            launch_time: task.launch_time,
+            start_exec_time,
+            end_exec_time,
+        };
+
         let task_status = as_task_status(
             execution_result,
             executor_id.clone(),
-            task_id.into(),
+            task_id,
+            stage_attempt_num,
+            part,
             Some(operator_metrics),
+            task_execution_times,
         );
 
+        let scheduler_id = curator_task.scheduler_id;
         let task_status_sender = self.executor_env.tx_task_status.clone();
         task_status_sender
             .send(CuratorTaskStatus {
@@ -549,21 +572,28 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskRunnerPool<T,
                     }
                 };
                 if let Some(curator_task) = maybe_task {
-                    let task_id = &curator_task.task.task_id;
-                    let task_id_log = format!(
-                        "{}/{}/{}",
-                        task_id.job_id, task_id.stage_id, task_id.partition_id
+                    let task_identity = format!(
+                        "TID {} {}/{}.{}/{}.{}",
+                        &curator_task.task.task_id,
+                        &curator_task.task.job_id,
+                        &curator_task.task.stage_id,
+                        &curator_task.task.stage_attempt_num,
+                        &curator_task.task.partition_id,
+                        &curator_task.task.task_attempt_num,
                     );
-                    info!("Received task {:?}", &task_id_log);
+                    info!("Received task {:?}", &task_identity);
 
                     let server = executor_server.clone();
                     dedicated_executor.spawn(async move {
-                        server.run_task(curator_task).await.unwrap_or_else(|e| {
-                            error!(
-                                "Fail to run the task {:?} due to {:?}",
-                                task_id_log, e
-                            );
-                        });
+                        server
+                            .run_task(task_identity.clone(), curator_task)
+                            .await
+                            .unwrap_or_else(|e| {
+                                error!(
+                                    "Fail to run the task {:?} due to {:?}",
+                                    task_identity, e
+                                );
+                            });
                     });
                 } else {
                     info!("Channel is closed and will exit the task receive loop");
@@ -650,18 +680,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
         &self,
         request: Request<CancelTasksParams>,
     ) -> Result<Response<CancelTasksResult>, Status> {
-        let partitions = request.into_inner().partition_id;
-        info!("Cancelling partition tasks for {:?}", partitions);
+        let task_infos = request.into_inner().task_infos;
+        info!("Cancelling tasks for {:?}", task_infos);
 
         let mut cancelled = true;
 
-        for partition in partitions {
+        for task in task_infos {
             if let Err(e) = self
                 .executor
                 .cancel_task(
-                    partition.job_id,
-                    partition.stage_id as usize,
-                    partition.partition_id as usize,
+                    task.task_id as usize,
+                    task.job_id,
+                    task.stage_id as usize,
+                    task.partition_id as usize,
                 )
                 .await
             {
diff --git a/ballista/rust/executor/src/lib.rs b/ballista/rust/executor/src/lib.rs
index 34d43b4c..a86d0bfb 100644
--- a/ballista/rust/executor/src/lib.rs
+++ b/ballista/rust/executor/src/lib.rs
@@ -34,15 +34,26 @@ pub use standalone::new_standalone_executor;
 use log::info;
 
 use ballista_core::serde::protobuf::{
-    task_status, CompletedTask, FailedTask, OperatorMetricsSet, PartitionId,
-    ShuffleWritePartition, TaskStatus,
+    task_status, FailedTask, OperatorMetricsSet, ShuffleWritePartition, SuccessfulTask,
+    TaskStatus,
 };
+use ballista_core::serde::scheduler::PartitionId;
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub struct TaskExecutionTimes {
+    launch_time: u64,
+    start_exec_time: u64,
+    end_exec_time: u64,
+}
 
 pub fn as_task_status(
     execution_result: ballista_core::error::Result<Vec<ShuffleWritePartition>>,
     executor_id: String,
-    task_id: PartitionId,
+    task_id: usize,
+    stage_attempt_num: usize,
+    partition_id: PartitionId,
     operator_metrics: Option<Vec<OperatorMetricsSet>>,
+    execution_times: TaskExecutionTimes,
 ) -> TaskStatus {
     let metrics = operator_metrics.unwrap_or_default();
     match execution_result {
@@ -53,9 +64,16 @@ pub fn as_task_status(
                 metrics.len()
             );
             TaskStatus {
-                task_id: Some(task_id),
+                task_id: task_id as u32,
+                job_id: partition_id.job_id,
+                stage_id: partition_id.stage_id as u32,
+                stage_attempt_num: stage_attempt_num as u32,
+                partition_id: partition_id.partition_id as u32,
+                launch_time: execution_times.launch_time,
+                start_exec_time: execution_times.start_exec_time,
+                end_exec_time: execution_times.end_exec_time,
                 metrics,
-                status: Some(task_status::Status::Completed(CompletedTask {
+                status: Some(task_status::Status::Successful(SuccessfulTask {
                     executor_id,
                     partitions,
                 })),
@@ -66,11 +84,16 @@ pub fn as_task_status(
             info!("Task {:?} failed: {}", task_id, error_msg);
 
             TaskStatus {
-                task_id: Some(task_id),
+                task_id: task_id as u32,
+                job_id: partition_id.job_id,
+                stage_id: partition_id.stage_id as u32,
+                stage_attempt_num: stage_attempt_num as u32,
+                partition_id: partition_id.partition_id as u32,
+                launch_time: execution_times.launch_time,
+                start_exec_time: execution_times.start_exec_time,
+                end_exec_time: execution_times.end_exec_time,
                 metrics,
-                status: Some(task_status::Status::Failed(FailedTask {
-                    error: format!("Task failed due to Tokio error: {}", error_msg),
-                })),
+                status: Some(task_status::Status::Failed(FailedTask::from(e))),
             }
         }
     }
diff --git a/ballista/rust/scheduler/src/api/handlers.rs b/ballista/rust/scheduler/src/api/handlers.rs
index 80547c0d..440a7ab8 100644
--- a/ballista/rust/scheduler/src/api/handlers.rs
+++ b/ballista/rust/scheduler/src/api/handlers.rs
@@ -92,7 +92,7 @@ pub(crate) async fn get_jobs<T: AsLogicalPlan, U: AsExecutionPlan>(
                 Some(Status::Queued(_)) => "Queued".to_string(),
                 Some(Status::Running(_)) => "Running".to_string(),
                 Some(Status::Failed(error)) => format!("Failed: {}", error.error),
-                Some(Status::Completed(completed)) => {
+                Some(Status::Successful(completed)) => {
                     let num_rows = completed
                         .partition_location
                         .iter()
diff --git a/ballista/rust/scheduler/src/flight_sql.rs b/ballista/rust/scheduler/src/flight_sql.rs
index a6218c48..93ade1d4 100644
--- a/ballista/rust/scheduler/src/flight_sql.rs
+++ b/ballista/rust/scheduler/src/flight_sql.rs
@@ -47,9 +47,9 @@ use ballista_core::config::BallistaConfig;
 use ballista_core::serde::protobuf;
 use ballista_core::serde::protobuf::action::ActionType::FetchPartition;
 use ballista_core::serde::protobuf::job_status;
-use ballista_core::serde::protobuf::CompletedJob;
 use ballista_core::serde::protobuf::JobStatus;
 use ballista_core::serde::protobuf::PhysicalPlanNode;
+use ballista_core::serde::protobuf::SuccessfulJob;
 use ballista_core::utils::create_grpc_client_connection;
 use datafusion::arrow;
 use datafusion::arrow::datatypes::Schema;
@@ -146,7 +146,7 @@ impl FlightSqlServiceImpl {
         Ok(plan)
     }
 
-    async fn check_job(&self, job_id: &String) -> Result<Option<CompletedJob>, Status> {
+    async fn check_job(&self, job_id: &String) -> Result<Option<SuccessfulJob>, Status> {
         let status = self
             .server
             .state
@@ -184,13 +184,13 @@ impl FlightSqlServiceImpl {
                     e.error
                 )))?
             }
-            job_status::Status::Completed(comp) => Ok(Some(comp)),
+            job_status::Status::Successful(comp) => Ok(Some(comp)),
         }
     }
 
     async fn job_to_fetch_part(
         &self,
-        completed: CompletedJob,
+        completed: SuccessfulJob,
         num_rows: &mut i64,
         num_bytes: &mut i64,
     ) -> Result<Vec<FlightEndpoint>, Status> {
diff --git a/ballista/rust/scheduler/src/scheduler_server/event.rs b/ballista/rust/scheduler/src/scheduler_server/event.rs
index ad462944..c748bc75 100644
--- a/ballista/rust/scheduler/src/scheduler_server/event.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/event.rs
@@ -19,6 +19,7 @@ use crate::state::executor_manager::ExecutorReservation;
 
 use datafusion::logical_plan::LogicalPlan;
 
+use crate::state::execution_graph::RunningTaskInfo;
 use ballista_core::serde::protobuf::TaskStatus;
 use datafusion::prelude::SessionContext;
 use std::sync::Arc;
@@ -35,9 +36,10 @@ pub enum QueryStageSchedulerEvent {
     JobPlanningFailed(String, String),
     JobFinished(String),
     // For a job fails with its execution graph setting failed
-    JobRunningFailed(String),
+    JobRunningFailed(String, String),
     JobUpdated(String),
     TaskUpdating(String, Vec<TaskStatus>),
     ReservationOffering(Vec<ExecutorReservation>),
     ExecutorLost(String, Option<String>),
+    CancelTasks(Vec<RunningTaskInfo>),
 }
diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
index b4aef6e6..3563a50e 100644
--- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs
@@ -526,17 +526,21 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
         let job_id = request.into_inner().job_id;
         info!("Received cancellation request for job {}", job_id);
 
-        self.state
-            .task_manager
-            .cancel_job(&job_id, &self.state.executor_manager)
-            .await
-            .map_err(|e| {
+        match self.state.task_manager.cancel_job(&job_id).await {
+            Ok(tasks) => {
+                self.state.executor_manager.cancel_running_tasks(tasks).await.map_err(|e| {
+                        let msg = format!("Error to cancel running task when cancel the job {} due to {:?}", job_id, e);
+                        error!("{}", msg);
+                        Status::internal(msg)
+                })?;
+                Ok(Response::new(CancelJobResult { cancelled: true }))
+            }
+            Err(e) => {
                 let msg = format!("Error cancelling job {}: {:?}", job_id, e);
-
                 error!("{}", msg);
-                Status::internal(msg)
-            })?;
-        Ok(Response::new(CancelJobResult { cancelled: true }))
+                Ok(Response::new(CancelJobResult { cancelled: false }))
+            }
+        }
     }
 }
 
diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs
index f428cca2..ac291151 100644
--- a/ballista/rust/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs
@@ -292,8 +292,8 @@ mod test {
     use ballista_core::error::Result;
 
     use ballista_core::serde::protobuf::{
-        job_status, task_status, CompletedTask, FailedTask, JobStatus, PartitionId,
-        PhysicalPlanNode, ShuffleWritePartition, TaskStatus,
+        failed_task, job_status, task_status, ExecutionError, FailedTask, JobStatus,
+        PhysicalPlanNode, ShuffleWritePartition, SuccessfulTask, TaskStatus,
     };
     use ballista_core::serde::scheduler::{
         ExecutorData, ExecutorMetadata, ExecutorSpecification,
@@ -372,16 +372,19 @@ mod test {
 
                 // Complete the task
                 let task_status = TaskStatus {
-                    status: Some(task_status::Status::Completed(CompletedTask {
+                    task_id: task.task_id as u32,
+                    job_id: task.partition.job_id.clone(),
+                    stage_id: task.partition.stage_id as u32,
+                    stage_attempt_num: task.stage_attempt_num as u32,
+                    partition_id: task.partition.partition_id as u32,
+                    launch_time: 0,
+                    start_exec_time: 0,
+                    end_exec_time: 0,
+                    metrics: vec![],
+                    status: Some(task_status::Status::Successful(SuccessfulTask {
                         executor_id: "executor-1".to_owned(),
                         partitions,
                     })),
-                    metrics: vec![],
-                    task_id: Some(PartitionId {
-                        job_id: job_id.to_owned(),
-                        stage_id: task.partition.stage_id as u32,
-                        partition_id: task.partition.partition_id as u32,
-                    }),
                 };
 
                 scheduler
@@ -401,7 +404,7 @@ mod test {
             .expect("Fail to find graph in the cache");
 
         let final_graph = final_graph.read().await;
-        assert!(final_graph.complete());
+        assert!(final_graph.is_successful());
         assert_eq!(final_graph.output_locations().len(), 4);
 
         for output_location in final_graph.output_locations() {
@@ -452,7 +455,7 @@ mod test {
                     .await
                     .unwrap();
                 let graph = graph.read().await;
-                if graph.complete() {
+                if graph.is_successful() {
                     break;
                 }
                 graph.available_tasks()
@@ -506,18 +509,21 @@ mod test {
 
                                 // Complete the task
                                 let task_status = TaskStatus {
-                                    status: Some(task_status::Status::Completed(
-                                        CompletedTask {
+                                    task_id: task.task_id as u32,
+                                    job_id: task.partition.job_id.clone(),
+                                    stage_id: task.partition.stage_id as u32,
+                                    stage_attempt_num: task.stage_attempt_num as u32,
+                                    partition_id: task.partition.partition_id as u32,
+                                    launch_time: 0,
+                                    start_exec_time: 0,
+                                    end_exec_time: 0,
+                                    metrics: vec![],
+                                    status: Some(task_status::Status::Successful(
+                                        SuccessfulTask {
                                             executor_id: executor.id.clone(),
                                             partitions,
                                         },
                                     )),
-                                    metrics: vec![],
-                                    task_id: Some(PartitionId {
-                                        job_id: job_id.to_owned(),
-                                        stage_id: task.partition.stage_id as u32,
-                                        partition_id: task.partition.partition_id as u32,
-                                    }),
                                 };
 
                                 scheduler
@@ -552,7 +558,7 @@ mod test {
             .get_execution_graph(job_id)
             .await?;
 
-        assert!(final_graph.complete());
+        assert!(final_graph.is_successful());
         assert_eq!(final_graph.output_locations().len(), 4);
 
         Ok(())
@@ -637,15 +643,25 @@ mod test {
 
                             // Complete the task
                             let task_status = TaskStatus {
+                                task_id: task.task_id as u32,
+                                job_id: task.partition.job_id.clone(),
+                                stage_id: task.partition.stage_id as u32,
+                                stage_attempt_num: task.stage_attempt_num as u32,
+                                partition_id: task.partition.partition_id as u32,
+                                launch_time: 0,
+                                start_exec_time: 0,
+                                end_exec_time: 0,
+                                metrics: vec![],
                                 status: Some(task_status::Status::Failed(FailedTask {
                                     error: "".to_string(),
+                                    retryable: false,
+                                    count_to_failures: false,
+                                    failed_reason: Some(
+                                        failed_task::FailedReason::ExecutionError(
+                                            ExecutionError {},
+                                        ),
+                                    ),
                                 })),
-                                metrics: vec![],
-                                task_id: Some(PartitionId {
-                                    job_id: job_id.to_owned(),
-                                    stage_id: task.partition.stage_id as u32,
-                                    partition_id: task.partition.partition_id as u32,
-                                }),
                             };
 
                             scheduler
diff --git a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
index 7d186fcd..02a22d5a 100644
--- a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
+++ b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs
@@ -125,20 +125,29 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
                         .await?;
                 }
             }
-            QueryStageSchedulerEvent::JobPlanningFailed(job_id, fail_message) => {
-                error!("Job {} failed: {}", job_id, fail_message);
+            QueryStageSchedulerEvent::JobPlanningFailed(job_id, failure_reason) => {
+                error!("Job {} failed: {}", job_id, failure_reason);
                 self.state
                     .task_manager
-                    .fail_job(&job_id, fail_message)
+                    .fail_unscheduled_job(&job_id, failure_reason)
                     .await?;
             }
             QueryStageSchedulerEvent::JobFinished(job_id) => {
-                info!("Job {} complete", job_id);
-                self.state.task_manager.complete_job(&job_id).await?;
+                info!("Job {} success", job_id);
+                self.state.task_manager.succeed_job(&job_id).await?;
             }
-            QueryStageSchedulerEvent::JobRunningFailed(job_id) => {
+            QueryStageSchedulerEvent::JobRunningFailed(job_id, failure_reason) => {
                 error!("Job {} running failed", job_id);
-                self.state.task_manager.fail_running_job(&job_id).await?;
+                let tasks = self
+                    .state
+                    .task_manager
+                    .abort_job(&job_id, failure_reason)
+                    .await?;
+                if !tasks.is_empty() {
+                    tx_event
+                        .post_event(QueryStageSchedulerEvent::CancelTasks(tasks))
+                        .await?;
+                }
             }
             QueryStageSchedulerEvent::JobUpdated(job_id) => {
                 info!("Job {} Updated", job_id);
@@ -184,17 +193,28 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
                 }
             }
             QueryStageSchedulerEvent::ExecutorLost(executor_id, _) => {
-                self.state
-                    .task_manager
-                    .executor_lost(&executor_id)
-                    .await
-                    .unwrap_or_else(|e| {
+                match self.state.task_manager.executor_lost(&executor_id).await {
+                    Ok(tasks) => {
+                        if !tasks.is_empty() {
+                            tx_event
+                                .post_event(QueryStageSchedulerEvent::CancelTasks(tasks))
+                                .await?;
+                        }
+                    }
+                    Err(e) => {
                         let msg = format!(
                             "TaskManager error to handle Executor {} lost: {}",
                             executor_id, e
                         );
                         error!("{}", msg);
-                    });
+                    }
+                }
+            }
+            QueryStageSchedulerEvent::CancelTasks(tasks) => {
+                self.state
+                    .executor_manager
+                    .cancel_running_tasks(tasks)
+                    .await?
             }
         }
 
diff --git a/ballista/rust/scheduler/src/state/execution_graph.rs b/ballista/rust/scheduler/src/state/execution_graph.rs
index 4835b850..a6c5c7db 100644
--- a/ballista/rust/scheduler/src/state/execution_graph.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph.rs
@@ -18,7 +18,9 @@
 use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
 use std::fmt::{Debug, Formatter};
+use std::iter::FromIterator;
 use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
 
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::{
@@ -30,9 +32,10 @@ use log::{error, info, warn};
 
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::execution_plans::{ShuffleWriterExec, UnresolvedShuffleExec};
+use ballista_core::serde::protobuf::failed_task::FailedReason;
 use ballista_core::serde::protobuf::{
-    self, execution_graph_stage::StageType, CompletedJob, JobStatus, QueuedJob,
-    TaskStatus,
+    self, execution_graph_stage::StageType, FailedTask, JobStatus, QueuedJob, ResultLost,
+    SuccessfulJob, TaskStatus,
 };
 use ballista_core::serde::protobuf::{job_status, FailedJob, ShuffleWritePartition};
 use ballista_core::serde::protobuf::{task_status, RunningTask};
@@ -45,9 +48,10 @@ use crate::display::print_stage_metrics;
 use crate::planner::DistributedPlanner;
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 pub(crate) use crate::state::execution_graph::execution_stage::{
-    CompletedStage, ExecutionStage, FailedStage, ResolvedStage, StageOutput,
+    ExecutionStage, FailedStage, ResolvedStage, StageOutput, SuccessfulStage, TaskInfo,
     UnresolvedStage,
 };
+use crate::state::task_manager::UpdatedStages;
 
 mod execution_stage;
 
@@ -110,6 +114,20 @@ pub struct ExecutionGraph {
     output_partitions: usize,
     /// Locations of this `ExecutionGraph` final output locations
     output_locations: Vec<PartitionLocation>,
+    /// Task ID generator, generates a unique TID for each task in the execution graph
+    task_id_gen: usize,
+    /// Failed stage attempts, recorded to limit the number of stage-level retries.
+    /// Map from Stage ID -> Set<Stage_ATTEMPT_NUM>
+    failed_stage_attempts: HashMap<usize, HashSet<usize>>,
+}
+
+#[derive(Clone)]
+pub struct RunningTaskInfo {
+    pub task_id: usize,
+    pub job_id: String,
+    pub stage_id: usize,
+    pub partition_id: usize,
+    pub executor_id: String,
 }
 
 impl ExecutionGraph {
@@ -138,6 +156,8 @@ impl ExecutionGraph {
             stages,
             output_partitions,
             output_locations: vec![],
+            task_id_gen: 0,
+            failed_stage_attempts: HashMap::new(),
         })
     }
 
@@ -157,15 +177,21 @@ impl ExecutionGraph {
         self.stages.len()
     }
 
+    pub fn next_task_id(&mut self) -> usize {
+        let new_tid = self.task_id_gen;
+        self.task_id_gen += 1;
+        new_tid
+    }
+
     pub(crate) fn stages(&self) -> &HashMap<usize, ExecutionStage> {
         &self.stages
     }
 
-    /// An ExecutionGraph is complete if all its stages are complete
-    pub fn complete(&self) -> bool {
+    /// An ExecutionGraph is successful if all its stages are successful
+    pub fn is_successful(&self) -> bool {
         self.stages
             .values()
-            .all(|s| matches!(s, ExecutionStage::Completed(_)))
+            .all(|s| matches!(s, ExecutionStage::Successful(_)))
     }
 
     /// Revive the execution graph by converting the resolved stages to running stages
@@ -202,68 +228,183 @@ impl ExecutionGraph {
         &mut self,
         executor: &ExecutorMetadata,
         task_statuses: Vec<TaskStatus>,
-    ) -> Result<Option<QueryStageSchedulerEvent>> {
+        max_task_failures: usize,
+        max_stage_failures: usize,
+    ) -> Result<Vec<QueryStageSchedulerEvent>> {
         let job_id = self.job_id().to_owned();
         // First of all, classify the statuses by stages
         let mut job_task_statuses: HashMap<usize, Vec<TaskStatus>> = HashMap::new();
         for task_status in task_statuses {
-            if let Some(task_id) = task_status.task_id.as_ref() {
-                if task_id.job_id != job_id {
-                    return Err(BallistaError::Internal(format!(
-                        "Error updating job {}: Invalid task status job ID {}",
-                        job_id, task_id.job_id
-                    )));
-                }
-                let stage_task_statuses = job_task_statuses
-                    .entry(task_id.stage_id as usize)
-                    .or_insert_with(Vec::new);
-                stage_task_statuses.push(task_status);
-            } else {
-                error!("There's no task id when updating status");
-            }
+            let stage_id = task_status.stage_id as usize;
+            let stage_task_statuses =
+                job_task_statuses.entry(stage_id).or_insert_with(Vec::new);
+            stage_task_statuses.push(task_status);
         }
 
         // Revive before updating due to some updates not saved
         // It will be refined later
         self.revive();
 
-        let mut events = vec![];
+        let current_running_stages: HashSet<usize> =
+            HashSet::from_iter(self.running_stages());
+
+        // Copy the failed stage attempts from self
+        let mut failed_stage_attempts: HashMap<usize, HashSet<usize>> = HashMap::new();
+        for (stage_id, attempts) in self.failed_stage_attempts.iter() {
+            failed_stage_attempts
+                .insert(*stage_id, HashSet::from_iter(attempts.iter().copied()));
+        }
+
+        let mut resolved_stages = HashSet::new();
+        let mut successful_stages = HashSet::new();
+        let mut failed_stages = HashMap::new();
+        let mut rollback_running_stages = HashMap::new();
+        let mut resubmit_successful_stages: HashMap<usize, HashSet<usize>> =
+            HashMap::new();
+        let mut reset_running_stages: HashMap<usize, HashSet<usize>> = HashMap::new();
+
         for (stage_id, stage_task_statuses) in job_task_statuses {
             if let Some(stage) = self.stages.get_mut(&stage_id) {
                 if let ExecutionStage::Running(running_stage) = stage {
                     let mut locations = vec![];
                     for task_status in stage_task_statuses.into_iter() {
-                        if let TaskStatus {
-                            task_id:
-                                Some(protobuf::PartitionId {
-                                    job_id,
-                                    stage_id,
-                                    partition_id,
-                                }),
-                            metrics: operator_metrics,
-                            status: Some(status),
-                        } = task_status
                         {
                             let stage_id = stage_id as usize;
-                            let partition_id = partition_id as usize;
+                            let task_stage_attempt_num =
+                                task_status.stage_attempt_num as usize;
+                            if task_stage_attempt_num < running_stage.stage_attempt_num {
+                                warn!("Ignore TaskStatus update with TID {} as it's from Stage {}.{} and there is a more recent stage attempt {}.{} running",
+                                    task_status.task_id, stage_id, task_stage_attempt_num, stage_id, running_stage.stage_attempt_num);
+                                continue;
+                            }
+                            let partition_id = task_status.clone().partition_id as usize;
+                            let task_identity = format!(
+                                "TID {} {}/{}.{}/{}",
+                                task_status.task_id,
+                                job_id,
+                                stage_id,
+                                task_stage_attempt_num,
+                                partition_id
+                            );
+                            let operator_metrics = task_status.metrics.clone();
 
-                            running_stage
-                                .update_task_status(partition_id, status.clone());
+                            if !running_stage
+                                .update_task_info(partition_id, task_status.clone())
+                            {
+                                continue;
+                            }
 
-                            // TODO Should be able to reschedule this task.
-                            if let task_status::Status::Failed(failed_task) = status {
-                                events.push(StageEvent::StageFailed(
-                                    stage_id,
-                                    format!(
-                                        "Task {}/{}/{} failed: {}",
-                                        job_id, stage_id, partition_id, failed_task.error
-                                    ),
-                                ));
-                                break;
-                            } else if let task_status::Status::Completed(completed_task) =
-                                status
+                            if let Some(task_status::Status::Failed(failed_task)) =
+                                task_status.status
                             {
-                                // update task metrics for completed task
+                                let failed_reason = failed_task.failed_reason;
+
+                                match failed_reason {
+                                    Some(FailedReason::FetchPartitionError(
+                                        fetch_partiton_error,
+                                    )) => {
+                                        let failed_attempts = failed_stage_attempts
+                                            .entry(stage_id)
+                                            .or_insert_with(HashSet::new);
+                                        failed_attempts.insert(task_stage_attempt_num);
+                                        if failed_attempts.len() < max_stage_failures {
+                                            let map_stage_id = fetch_partiton_error
+                                                .map_stage_id
+                                                as usize;
+                                            let map_partition_id = fetch_partiton_error
+                                                .map_partition_id
+                                                as usize;
+                                            let executor_id =
+                                                fetch_partiton_error.executor_id;
+
+                                            if !failed_stages.is_empty() {
+                                                let error_msg = format!(
+                                                        "Stages was marked failed, ignore FetchPartitionError from task {}", task_identity);
+                                                warn!("{}", error_msg);
+                                            } else {
+                                                // There are different removal strategies here.
+                                                // We could remove only the map_partition_id in the FetchPartitionError: fewer tasks would rerun when the input stage is resubmitted,
+                                                // but this might miss many more bad input partitions and lead to more stage-level retries later.
+                                                // Here we choose to remove all the bad input partitions in this single input stage that match the same executor id.
+                                                // More aggressive approaches exist, like treating the executor as lost and checking all the running stages in this graph,
+                                                // or counting fetch failures per executor and marking the executor lost globally.
+                                                let removed_map_partitions =
+                                                    running_stage
+                                                        .remove_input_partitions(
+                                                            map_stage_id,
+                                                            map_partition_id,
+                                                            &executor_id,
+                                                        )?;
+
+                                                let failure_reasons =
+                                                    rollback_running_stages
+                                                        .entry(stage_id)
+                                                        .or_insert_with(HashSet::new);
+                                                failure_reasons.insert(executor_id);
+
+                                                let missing_inputs =
+                                                    resubmit_successful_stages
+                                                        .entry(map_stage_id)
+                                                        .or_insert_with(HashSet::new);
+                                                missing_inputs
+                                                    .extend(removed_map_partitions);
+                                                warn!("Need to resubmit the current running Stage {} and its map Stage {} due to FetchPartitionError from task {}",
+                                                    stage_id, map_stage_id, task_identity)
+                                            }
+                                        } else {
+                                            let error_msg = format!(
+                                                "Stage {} has failed {} times, \
+                                            most recent failure reason: {:?}",
+                                                stage_id,
+                                                max_stage_failures,
+                                                failed_task.error
+                                            );
+                                            error!("{}", error_msg);
+                                            failed_stages.insert(stage_id, error_msg);
+                                        }
+                                    }
+                                    Some(FailedReason::ExecutionError(_)) => {
+                                        failed_stages.insert(stage_id, failed_task.error);
+                                    }
+                                    Some(_) => {
+                                        if failed_task.retryable
+                                            && failed_task.count_to_failures
+                                        {
+                                            if running_stage
+                                                .task_failure_number(partition_id)
+                                                < max_task_failures
+                                            {
+                                                // TODO add new struct to track all the failed task infos
+                                                // The failure TaskInfo is ignored and set to None here
+                                                running_stage
+                                                    .reset_task_info(partition_id);
+                                            } else {
+                                                let error_msg = format!(
+                        "Task {} in Stage {} failed {} times, fail the stage, most recent failure reason: {:?}",
+                        partition_id, stage_id, max_task_failures, failed_task.error
+                    );
+                                                error!("{}", error_msg);
+                                                failed_stages.insert(stage_id, error_msg);
+                                            }
+                                        } else if failed_task.retryable {
+                                            // TODO add new struct to track all the failed task infos
+                                            // The failure TaskInfo is ignored and set to None here
+                                            running_stage.reset_task_info(partition_id);
+                                        }
+                                    }
+                                    None => {
+                                        let error_msg = format!(
+                                            "Task {} in Stage {} failed with unknown failure reasons, fail the stage",
+                                            partition_id, stage_id);
+                                        error!("{}", error_msg);
+                                        failed_stages.insert(stage_id, error_msg);
+                                    }
+                                }
+                            } else if let Some(task_status::Status::Successful(
+                                successful_task,
+                            )) = task_status.status
+                            {
+                                // update task metrics for successful task
                                 running_stage.update_task_metrics(
                                     partition_id,
                                     operator_metrics,
@@ -271,19 +412,24 @@ impl ExecutionGraph {
 
                                 locations.append(&mut partition_to_location(
                                     &job_id,
+                                    partition_id,
                                     stage_id,
                                     executor,
-                                    completed_task.partitions,
+                                    successful_task.partitions,
                                 ));
                             } else {
-                                warn!("The task {}/{}/{} with status {:?} is invalid for updating", job_id, stage_id, partition_id, status);
+                                warn!(
+                                    "The task {}'s status is invalid for updating",
+                                    task_identity
+                                );
                             }
                         }
                     }
-                    let is_completed = running_stage.is_completed();
-                    if is_completed {
-                        events.push(StageEvent::StageCompleted(stage_id));
-                        // if this stage is completed, we want to combine the stage metrics to plan's metric set and print out the plan
+                    let is_final_successful = running_stage.is_successful()
+                        && !reset_running_stages.contains_key(&stage_id);
+                    if is_final_successful {
+                        successful_stages.insert(stage_id);
+                        // if this stage is final successful, we want to combine the stage metrics to plan's metric set and print out the plan
                         if let Some(stage_metrics) = running_stage.stage_metrics.as_ref()
                         {
                             print_stage_metrics(
@@ -296,18 +442,105 @@ impl ExecutionGraph {
                     }
 
                     let output_links = running_stage.output_links.clone();
-                    events.append(&mut self.update_stage_output_links(
-                        stage_id,
-                        is_completed,
-                        locations,
-                        output_links,
-                    )?);
+                    resolved_stages.extend(
+                        &mut self
+                            .update_stage_output_links(
+                                stage_id,
+                                is_final_successful,
+                                locations,
+                                output_links,
+                            )?
+                            .into_iter(),
+                    );
+                } else if let ExecutionStage::UnResolved(unsolved_stage) = stage {
+                    for task_status in stage_task_statuses.into_iter() {
+                        let stage_id = stage_id as usize;
+                        let task_stage_attempt_num =
+                            task_status.stage_attempt_num as usize;
+                        let partition_id = task_status.clone().partition_id as usize;
+                        let task_identity = format!(
+                            "TID {} {}/{}.{}/{}",
+                            task_status.task_id,
+                            job_id,
+                            stage_id,
+                            task_stage_attempt_num,
+                            partition_id
+                        );
+                        let mut should_ignore = true;
+                        // handle delayed failed tasks if the stage's next attempt is still in UnResolved status.
+                        if let Some(task_status::Status::Failed(failed_task)) =
+                            task_status.status
+                        {
+                            if unsolved_stage.stage_attempt_num - task_stage_attempt_num
+                                == 1
+                            {
+                                let failed_reason = failed_task.failed_reason;
+                                match failed_reason {
+                                    Some(FailedReason::ExecutionError(_)) => {
+                                        should_ignore = false;
+                                        failed_stages.insert(stage_id, failed_task.error);
+                                    }
+                                    Some(FailedReason::FetchPartitionError(
+                                        fetch_partiton_error,
+                                    )) if failed_stages.is_empty()
+                                        && current_running_stages.contains(
+                                            &(fetch_partiton_error.map_stage_id as usize),
+                                        )
+                                        && !unsolved_stage
+                                            .last_attempt_failure_reasons
+                                            .contains(
+                                                &fetch_partiton_error.executor_id,
+                                            ) =>
+                                    {
+                                        should_ignore = false;
+                                        unsolved_stage
+                                            .last_attempt_failure_reasons
+                                            .insert(
+                                                fetch_partiton_error.executor_id.clone(),
+                                            );
+                                        let map_stage_id =
+                                            fetch_partiton_error.map_stage_id as usize;
+                                        let map_partition_id = fetch_partiton_error
+                                            .map_partition_id
+                                            as usize;
+                                        let executor_id =
+                                            fetch_partiton_error.executor_id;
+                                        let removed_map_partitions = unsolved_stage
+                                            .remove_input_partitions(
+                                                map_stage_id,
+                                                map_partition_id,
+                                                &executor_id,
+                                            )?;
+
+                                        let missing_inputs = reset_running_stages
+                                            .entry(map_stage_id)
+                                            .or_insert_with(HashSet::new);
+                                        missing_inputs.extend(removed_map_partitions);
+                                        warn!("Need to reset the current running Stage {} due to late come FetchPartitionError from its parent stage {} of task {}",
+                                                    map_stage_id, stage_id, task_identity);
+
+                                        // If the previous other task updates had already mark the map stage success, need to remove it.
+                                        if successful_stages.contains(&map_stage_id) {
+                                            successful_stages.remove(&map_stage_id);
+                                        }
+                                        if resolved_stages.contains(&stage_id) {
+                                            resolved_stages.remove(&stage_id);
+                                        }
+                                    }
+                                    _ => {}
+                                }
+                            }
+                        }
+                        if should_ignore {
+                            warn!("Ignore TaskStatus update of task with TID {} as the Stage {}/{} is in UnResolved status", task_identity, job_id, stage_id);
+                        }
+                    }
                 } else {
                     warn!(
                         "Stage {}/{} is not in running when updating the status of tasks {:?}",
                         job_id,
                         stage_id,
-                        stage_task_statuses.into_iter().map(|task_status| task_status.task_id.map(|task_id| task_id.partition_id)).collect::<Vec<_>>(),
+                        stage_task_statuses.into_iter().map(|task_status| task_status.partition_id).collect::<Vec<_>>(),
                     );
                 }
             } else {
@@ -318,17 +551,154 @@ impl ExecutionGraph {
             }
         }
 
-        self.processing_stage_events(events)
+        // Update failed stage attempts back to self
+        for (stage_id, attempts) in failed_stage_attempts.iter() {
+            self.failed_stage_attempts
+                .insert(*stage_id, HashSet::from_iter(attempts.iter().copied()));
+        }
+
+        for (stage_id, missing_parts) in &resubmit_successful_stages {
+            if let Some(stage) = self.stages.get_mut(stage_id) {
+                if let ExecutionStage::Successful(success_stage) = stage {
+                    for partition in missing_parts {
+                        if *partition > success_stage.partitions {
+                            return Err(BallistaError::Internal(format!(
+                                "Invalid partition ID {} in map stage {}",
+                                *partition, stage_id
+                            )));
+                        }
+                        let task_info = &mut success_stage.task_infos[*partition];
+                        // Update the task info to failed
+                        task_info.task_status = task_status::Status::Failed(FailedTask {
+                            error: "FetchPartitionError in parent stage".to_owned(),
+                            retryable: true,
+                            count_to_failures: false,
+                            failed_reason: Some(FailedReason::ResultLost(ResultLost {})),
+                        });
+                    }
+                } else {
+                    warn!(
+                        "Stage {}/{} is not in Successful state when try to resubmit this stage. ",
+                        job_id,
+                        stage_id);
+                }
+            } else {
+                return Err(BallistaError::Internal(format!(
+                    "Invalid stage ID {} for job {}",
+                    stage_id, job_id
+                )));
+            }
+        }
+
+        for (stage_id, missing_parts) in &reset_running_stages {
+            if let Some(stage) = self.stages.get_mut(stage_id) {
+                if let ExecutionStage::Running(running_stage) = stage {
+                    for partition in missing_parts {
+                        if *partition > running_stage.partitions {
+                            return Err(BallistaError::Internal(format!(
+                                "Invalid partition ID {} in map stage {}",
+                                *partition, stage_id
+                            )));
+                        }
+                        running_stage.reset_task_info(*partition);
+                    }
+                } else {
+                    warn!(
+                        "Stage {}/{} is not in Running state when try to reset the running task. ",
+                        job_id,
+                        stage_id);
+                }
+            } else {
+                return Err(BallistaError::Internal(format!(
+                    "Invalid stage ID {} for job {}",
+                    stage_id, job_id
+                )));
+            }
+        }
+
+        self.processing_stages_update(UpdatedStages {
+            resolved_stages,
+            successful_stages,
+            failed_stages,
+            rollback_running_stages,
+            resubmit_successful_stages: resubmit_successful_stages
+                .keys()
+                .cloned()
+                .collect(),
+        })
+    }
+
+    /// Apply the stage state changes collected while handling a batch of task
+    /// status updates, and return the scheduler events that result.
+    ///
+    /// * Every stage in `resolved_stages` is resolved and every stage in
+    ///   `successful_stages` is marked successful.
+    /// * If `failed_stages` is non-empty, the job is failed and a single
+    ///   `JobRunningFailed` event is emitted whose message has one line per
+    ///   failed stage; rollbacks and reruns are skipped in that case.
+    /// * Otherwise the requested running-stage rollbacks and successful-stage
+    ///   reruns are applied, and the running tasks returned by the rollbacks are
+    ///   emitted in one `CancelTasks` event. Then `JobFinished` is emitted if the
+    ///   whole graph is now successful, or `JobUpdated` if any stage was resolved.
+    ///
+    /// Errors from `resolve_stage`, `rollback_running_stage` or `succeed_job`
+    /// are propagated to the caller.
+    fn processing_stages_update(
+        &mut self,
+        updated_stages: UpdatedStages,
+    ) -> Result<Vec<QueryStageSchedulerEvent>> {
+        let job_id = self.job_id().to_owned();
+        let mut has_resolved = false;
+
+        for stage_id in updated_stages.resolved_stages {
+            self.resolve_stage(stage_id)?;
+            has_resolved = true;
+        }
+
+        for stage_id in updated_stages.successful_stages {
+            self.succeed_stage(stage_id);
+        }
+
+        // Fail the stage and also abort the job. Append one line per failed stage
+        // instead of overwriting, so no failure reason is lost when several stages
+        // fail in the same batch; sort by stage id so the message does not depend
+        // on the map's iteration order.
+        let mut failed: Vec<_> = updated_stages.failed_stages.iter().collect();
+        failed.sort_unstable_by_key(|(stage_id, _)| **stage_id);
+        let mut job_err_msg = String::new();
+        for (stage_id, err_msg) in failed {
+            job_err_msg.push_str(&format!(
+                "Job failed due to stage {} failed: {}\n",
+                stage_id, err_msg
+            ));
+        }
+
+        let mut events = vec![];
+        // Only handle the rollback logic when there are no failed stages
+        if updated_stages.failed_stages.is_empty() {
+            let mut running_tasks_to_cancel = vec![];
+            for (stage_id, failure_reasons) in updated_stages.rollback_running_stages {
+                let tasks = self.rollback_running_stage(stage_id, failure_reasons)?;
+                running_tasks_to_cancel.extend(tasks);
+            }
+
+            for stage_id in updated_stages.resubmit_successful_stages {
+                self.rerun_successful_stage(stage_id);
+            }
+
+            if !running_tasks_to_cancel.is_empty() {
+                events.push(QueryStageSchedulerEvent::CancelTasks(
+                    running_tasks_to_cancel,
+                ));
+            }
+        }
+
+        if !updated_stages.failed_stages.is_empty() {
+            info!("Job {} is failed", job_id);
+            self.fail_job(job_err_msg.clone());
+            events.push(QueryStageSchedulerEvent::JobRunningFailed(
+                job_id,
+                job_err_msg,
+            ));
+        } else if self.is_successful() {
+            // If this ExecutionGraph is successful, finish it
+            info!("Job {} is success, finalizing output partitions", job_id);
+            self.succeed_job()?;
+            events.push(QueryStageSchedulerEvent::JobFinished(job_id));
+        } else if has_resolved {
+            events.push(QueryStageSchedulerEvent::JobUpdated(job_id))
+        }
+        Ok(events)
     }
 
+    /// Return a Vec of resolvable stage ids
     fn update_stage_output_links(
         &mut self,
         stage_id: usize,
         is_completed: bool,
         locations: Vec<PartitionLocation>,
         output_links: Vec<usize>,
-    ) -> Result<Vec<StageEvent>> {
-        let mut ret = vec![];
+    ) -> Result<Vec<usize>> {
+        let mut resolved_stages = vec![];
         let job_id = &self.job_id;
         if output_links.is_empty() {
             // If `output_links` is empty, then this is a final stage
@@ -350,9 +720,7 @@ impl ExecutionGraph {
 
                         // If all input partitions are ready, we can resolve any UnresolvedShuffleExec in the parent stage plan
                         if linked_unresolved_stage.resolvable() {
-                            ret.push(StageEvent::StageResolved(
-                                linked_unresolved_stage.stage_id,
-                            ));
+                            resolved_stages.push(linked_unresolved_stage.stage_id);
                         }
                     } else {
                         return Err(BallistaError::Internal(format!(
@@ -368,12 +736,25 @@ impl ExecutionGraph {
                 }
             }
         }
+        Ok(resolved_stages)
+    }
 
-        Ok(ret)
+    /// Collect the ids of every stage whose current state is `Running`,
+    /// in the iteration order of `self.stages`.
+    pub fn running_stages(&self) -> Vec<usize> {
+        self.stages
+            .iter()
+            .filter(|(_, stage)| matches!(stage, ExecutionStage::Running(_)))
+            .map(|(stage_id, _)| *stage_id)
+            .collect()
     }
 
     /// Return all currently running tasks along with the executor ID on which they are assigned
-    pub fn running_tasks(&self) -> Vec<(PartitionId, String)> {
+    pub fn running_tasks(&self) -> Vec<RunningTaskInfo> {
         self.stages
             .iter()
             .flat_map(|(_, stage)| {
@@ -381,22 +762,21 @@ impl ExecutionGraph {
                     stage
                         .running_tasks()
                         .into_iter()
-                        .map(|(stage_id, partition_id, executor_id)| {
-                            (
-                                PartitionId {
-                                    job_id: self.job_id.clone(),
-                                    stage_id,
-                                    partition_id,
-                                },
+                        .map(|(task_id, stage_id, partition_id, executor_id)| {
+                            RunningTaskInfo {
+                                task_id,
+                                job_id: self.job_id.clone(),
+                                stage_id,
+                                partition_id,
                                 executor_id,
-                            )
+                            }
                         })
-                        .collect::<Vec<(PartitionId, String)>>()
+                        .collect::<Vec<RunningTaskInfo>>()
                 } else {
                     vec![]
                 }
             })
-            .collect::<Vec<(PartitionId, String)>>()
+            .collect::<Vec<RunningTaskInfo>>()
     }
 
     /// Total number of tasks in this plan that are ready for scheduling
@@ -419,9 +799,36 @@ impl ExecutionGraph {
     /// available to the scheduler.
     /// If the task is not launched the status must be reset to allow the task to
     /// be scheduled elsewhere.
-    pub fn pop_next_task(&mut self, executor_id: &str) -> Result<Option<Task>> {
+    pub fn pop_next_task(
+        &mut self,
+        executor_id: &str,
+    ) -> Result<Option<TaskDescription>> {
+        if matches!(
+            self.status,
+            JobStatus {
+                status: Some(job_status::Status::Failed(_)),
+            }
+        ) {
+            warn!("Call pop_next_task on failed Job");
+            return Ok(None);
+        }
+
         let job_id = self.job_id.clone();
         let session_id = self.session_id.clone();
+
+        let find_candidate = self.stages.iter().any(|(_stage_id, stage)| {
+            if let ExecutionStage::Running(stage) = stage {
+                stage.available_tasks() > 0
+            } else {
+                false
+            }
+        });
+        let next_task_id = if find_candidate {
+            Some(self.next_task_id())
+        } else {
+            None
+        };
+
         let mut next_task = self.stages.iter_mut().find(|(_stage_id, stage)| {
             if let ExecutionStage::Running(stage) = stage {
                 stage.available_tasks() > 0
@@ -431,10 +838,10 @@ impl ExecutionGraph {
         }).map(|(stage_id, stage)| {
             if let ExecutionStage::Running(stage) = stage {
                 let (partition_id, _) = stage
-                    .task_statuses
+                    .task_infos
                     .iter()
                     .enumerate()
-                    .find(|(_partition, status)| status.is_none())
+                    .find(|(_partition, info)| info.is_none())
                     .ok_or_else(|| {
                         BallistaError::Internal(format!("Error getting next task for job {}: Stage {} is ready but has no pending tasks", job_id, stage_id))
                     })?;
@@ -445,14 +852,33 @@ impl ExecutionGraph {
                     partition_id,
                 };
 
-                // Set the status to Running
-                stage.task_statuses[partition_id] = Some(task_status::Status::Running(RunningTask {
-                    executor_id: executor_id.to_owned()
-                }));
+                let task_id = next_task_id.unwrap();
+                let task_attempt = stage.task_failure_numbers[partition_id];
+                let task_info = TaskInfo {
+                    task_id,
+                    scheduled_time: SystemTime::now()
+                    .duration_since(UNIX_EPOCH)
+                    .unwrap()
+                    .as_millis(),
+                    // Those times will be updated when the task finish
+                    launch_time: 0,
+                    start_exec_time: 0,
+                    end_exec_time: 0,
+                    finish_time: 0,
+                    task_status: task_status::Status::Running(RunningTask {
+                        executor_id: executor_id.to_owned()
+                    }),
+                };
+
+                // Set the task info to Running for new task
+                stage.task_infos[partition_id] = Some(task_info);
 
-                Ok(Task {
+                Ok(TaskDescription {
                     session_id,
                     partition,
+                    stage_attempt_num: stage.stage_attempt_num,
+                    task_id,
+                    task_attempt,
                     plan: stage.plan.clone(),
                     output_partitioning: stage.output_partitioning.clone(),
                 })
@@ -478,47 +904,47 @@ impl ExecutionGraph {
         self.status = status;
     }
 
-    /// Reset the status for the given task. This should be called is a task failed to
-    /// launch and it needs to be returned to the set of available tasks and be
-    /// re-scheduled.
-    pub fn reset_task_status(&mut self, task: Task) {
-        let stage_id = task.partition.stage_id;
-        let partition = task.partition.partition_id;
-
-        if let Some(ExecutionStage::Running(stage)) = self.stages.get_mut(&stage_id) {
-            stage.task_statuses[partition] = None;
-        }
-    }
-
+    /// Return a clone of the output partition locations stored on this graph.
     pub fn output_locations(&self) -> Vec<PartitionLocation> {
         self.output_locations.clone()
     }
 
-    /// Reset running and completed stages on a given executor
-    /// This will first check the unresolved/resolved/running stages and reset the running tasks and completed tasks.
-    /// Then it will check the completed stage and whether there are running parent stages need to read shuffle from it.
-    /// If yes, reset the complete tasks and roll back the resolved shuffle recursively.
+    /// Reset the running and successful stages affected by the loss of the given executor.
+    ///
+    /// This first checks the unresolved/resolved/running stages, resetting their running tasks
+    /// and dropping input partition locations that live on the lost executor. It then checks
+    /// the successful input stages whose output was on that executor: their tasks are reset,
+    /// and the resolved/running stages that read shuffle data from them are rolled back.
+    /// The check is repeated until nothing more is reset, so rollbacks propagate recursively.
     ///
-    /// Returns the reset stage ids
-    pub fn reset_stages(&mut self, executor_id: &str) -> Result<HashSet<usize>> {
+    /// Returns the ids of every stage reset across all passes, together with the running
+    /// tasks that should be cancelled because their stage was rolled back.
+    pub fn reset_stages_on_lost_executor(
+        &mut self,
+        executor_id: &str,
+    ) -> Result<(HashSet<usize>, Vec<RunningTaskInfo>)> {
         let mut reset = HashSet::new();
+        let mut tasks_to_cancel = vec![];
+        // Iterate to a fixpoint: stop at the first pass that reports no reset stage,
+        // since a rollback in one pass may require resetting further stages in the next.
         loop {
             let reset_stage = self.reset_stages_internal(executor_id)?;
-            if !reset_stage.is_empty() {
-                reset.extend(reset_stage.iter());
+            // reset_stage.0: stage ids reset in this pass; reset_stage.1: running tasks to cancel.
+            if !reset_stage.0.is_empty() {
+                reset.extend(reset_stage.0.iter());
+                tasks_to_cancel.extend(reset_stage.1)
             } else {
-                return Ok(reset);
+                return Ok((reset, tasks_to_cancel));
             }
         }
     }
 
-    fn reset_stages_internal(&mut self, executor_id: &str) -> Result<HashSet<usize>> {
-        let mut reset_stage = HashSet::new();
+    fn reset_stages_internal(
+        &mut self,
+        executor_id: &str,
+    ) -> Result<(HashSet<usize>, Vec<RunningTaskInfo>)> {
         let job_id = self.job_id.clone();
-        let mut stage_events = vec![];
+        // collect the input stages that need to resubmit
         let mut resubmit_inputs: HashSet<usize> = HashSet::new();
-        let mut empty_inputs: HashMap<usize, StageOutput> = HashMap::new();
 
+        let mut reset_running_stage = HashSet::new();
+        let mut rollback_resolved_stages = HashSet::new();
+        let mut rollback_running_stages = HashSet::new();
+        let mut resubmit_successful_stages = HashSet::new();
+
+        let mut empty_inputs: HashMap<usize, StageOutput> = HashMap::new();
         // check the unresolved, resolved and running stages
         self.stages
             .iter_mut()
@@ -537,7 +963,7 @@ impl ExecutionGraph {
                         "Reset {} tasks for running job/stage {}/{} on lost Executor {}",
                         reset, job_id, stage_id, executor_id
                         );
-                            reset_stage.insert(*stage_id);
+                            reset_running_stage.insert(*stage_id);
                         }
                         &mut stage.inputs
                     }
@@ -545,25 +971,15 @@ impl ExecutionGraph {
                 };
 
                 // For each stage input, check whether there are input locations match that executor
-                // and calculate the resubmit input stages if the input stages are completed.
+                // and calculate the resubmit input stages if the input stages are successful.
                 let mut rollback_stage = false;
                 stage_inputs.iter_mut().for_each(|(input_stage_id, stage_output)| {
                     let mut match_found = false;
                     stage_output.partition_locations.iter_mut().for_each(
                         |(_partition, locs)| {
-                            let indexes = locs
-                                .iter()
-                                .enumerate()
-                                .filter_map(|(idx, loc)| {
-                                    (loc.executor_meta.id == executor_id).then_some(idx)
-                                })
-                                .collect::<Vec<_>>();
-
-                            // remove the matched partition locations
-                            if !indexes.is_empty() {
-                                for idx in &indexes {
-                                    locs.remove(*idx);
-                                }
+                            let before_len = locs.len();
+                            locs.retain(|loc| loc.executor_meta.id != executor_id);
+                            if locs.len() < before_len {
                                 match_found = true;
                             }
                         },
@@ -578,32 +994,31 @@ impl ExecutionGraph {
                 if rollback_stage {
                     match stage {
                         ExecutionStage::Resolved(_) => {
-                            stage_events.push(StageEvent::RollBackResolvedStage(*stage_id));
+                            rollback_resolved_stages.insert(*stage_id);
                             warn!(
                             "Roll back resolved job/stage {}/{} and change ShuffleReaderExec back to UnresolvedShuffleExec",
                             job_id, stage_id);
-                            reset_stage.insert(*stage_id);
+
                         },
                         ExecutionStage::Running(_) => {
-                            stage_events.push(StageEvent::RollBackRunningStage(*stage_id));
+                            rollback_running_stages.insert(*stage_id);
                             warn!(
                             "Roll back running job/stage {}/{} and change ShuffleReaderExec back to UnresolvedShuffleExec",
                             job_id, stage_id);
-                            reset_stage.insert(*stage_id);
                         },
                         _ => {},
                     }
                 }
             });
 
-        // check and reset the complete stages
+        // check and reset the successful stages
         if !resubmit_inputs.is_empty() {
             self.stages
                 .iter_mut()
                 .filter(|(stage_id, _stage)| resubmit_inputs.contains(stage_id))
                 .filter_map(|(_stage_id, stage)| {
-                    if let ExecutionStage::Completed(completed) = stage {
-                        Some(completed)
+                    if let ExecutionStage::Successful(success) = stage {
+                        Some(success)
                     } else {
                         None
                     }
@@ -611,79 +1026,42 @@ impl ExecutionGraph {
                 .for_each(|stage| {
                     let reset = stage.reset_tasks(executor_id);
                     if reset > 0 {
-                        stage_events
-                            .push(StageEvent::ReRunCompletedStage(stage.stage_id));
-                        reset_stage.insert(stage.stage_id);
+                        resubmit_successful_stages.insert(stage.stage_id);
                         warn!(
-                            "Reset {} tasks for completed job/stage {}/{} on lost Executor {}",
+                            "Reset {} tasks for successful job/stage {}/{} on lost Executor {}",
                             reset, job_id, stage.stage_id, executor_id
                         )
                     }
                 });
         }
-        self.processing_stage_events(stage_events)?;
-        Ok(reset_stage)
-    }
 
-    /// Processing stage events for stage state changing
-    pub fn processing_stage_events(
-        &mut self,
-        events: Vec<StageEvent>,
-    ) -> Result<Option<QueryStageSchedulerEvent>> {
-        let mut has_resolved = false;
-        let mut job_err_msg = "".to_owned();
-        for event in events {
-            match event {
-                StageEvent::StageResolved(stage_id) => {
-                    self.resolve_stage(stage_id)?;
-                    has_resolved = true;
-                }
-                StageEvent::StageCompleted(stage_id) => {
-                    self.complete_stage(stage_id);
-                }
-                StageEvent::StageFailed(stage_id, err_msg) => {
-                    job_err_msg = format!("{}{}\n", job_err_msg, &err_msg);
-                    self.fail_stage(stage_id, err_msg);
-                }
-                StageEvent::RollBackRunningStage(stage_id) => {
-                    self.rollback_running_stage(stage_id)?;
-                }
-                StageEvent::RollBackResolvedStage(stage_id) => {
-                    self.rollback_resolved_stage(stage_id)?;
-                }
-                StageEvent::ReRunCompletedStage(stage_id) => {
-                    self.rerun_completed_stage(stage_id);
-                }
-            }
+        for stage_id in rollback_resolved_stages.iter() {
+            self.rollback_resolved_stage(*stage_id)?;
         }
 
-        let event = if !job_err_msg.is_empty() {
-            // If this ExecutionGraph is complete, fail it
-            info!("Job {} is failed", self.job_id());
-            self.fail_job(job_err_msg);
+        let mut all_running_tasks = vec![];
+        for stage_id in rollback_running_stages.iter() {
+            let tasks = self.rollback_running_stage(
+                *stage_id,
+                HashSet::from([executor_id.to_owned()]),
+            )?;
+            all_running_tasks.extend(tasks);
+        }
 
-            Some(QueryStageSchedulerEvent::JobRunningFailed(
-                self.job_id.clone(),
-            ))
-        } else if self.complete() {
-            // If this ExecutionGraph is complete, finalize it
-            info!(
-                "Job {} is complete, finalizing output partitions",
-                self.job_id()
-            );
-            self.complete_job()?;
-            Some(QueryStageSchedulerEvent::JobFinished(self.job_id.clone()))
-        } else if has_resolved {
-            Some(QueryStageSchedulerEvent::JobUpdated(self.job_id.clone()))
-        } else {
-            None
-        };
+        for stage_id in resubmit_successful_stages.iter() {
+            self.rerun_successful_stage(*stage_id);
+        }
 
-        Ok(event)
+        let mut reset_stage = HashSet::new();
+        reset_stage.extend(reset_running_stage);
+        reset_stage.extend(rollback_resolved_stages);
+        reset_stage.extend(rollback_running_stages);
+        reset_stage.extend(resubmit_successful_stages);
+        Ok((reset_stage, all_running_tasks))
     }
 
     /// Convert unresolved stage to be resolved
-    fn resolve_stage(&mut self, stage_id: usize) -> Result<bool> {
+    pub fn resolve_stage(&mut self, stage_id: usize) -> Result<bool> {
         if let Some(ExecutionStage::UnResolved(stage)) = self.stages.remove(&stage_id) {
             self.stages
                 .insert(stage_id, ExecutionStage::Resolved(stage.to_resolved()?));
@@ -698,15 +1076,16 @@ impl ExecutionGraph {
         }
     }
 
-    /// Convert running stage to be completed
-    fn complete_stage(&mut self, stage_id: usize) -> bool {
+    /// Convert running stage to be successful
+    pub fn succeed_stage(&mut self, stage_id: usize) -> bool {
         if let Some(ExecutionStage::Running(stage)) = self.stages.remove(&stage_id) {
             self.stages
-                .insert(stage_id, ExecutionStage::Completed(stage.to_completed()));
+                .insert(stage_id, ExecutionStage::Successful(stage.to_successful()));
+            self.clear_stage_failure(stage_id);
             true
         } else {
             warn!(
-                "Fail to find a running stage {}/{} to complete",
+                "Fail to find a running stage {}/{} to make it success",
                 self.job_id(),
                 stage_id
             );
@@ -715,13 +1094,13 @@ impl ExecutionGraph {
     }
 
     /// Convert running stage to be failed
-    fn fail_stage(&mut self, stage_id: usize, err_msg: String) -> bool {
+    pub fn fail_stage(&mut self, stage_id: usize, err_msg: String) -> bool {
         if let Some(ExecutionStage::Running(stage)) = self.stages.remove(&stage_id) {
             self.stages
                 .insert(stage_id, ExecutionStage::Failed(stage.to_failed(err_msg)));
             true
         } else {
-            warn!(
+            info!(
                 "Fail to find a running stage {}/{} to fail",
                 self.job_id(),
                 stage_id
@@ -730,24 +1109,44 @@ impl ExecutionGraph {
         }
     }
 
-    /// Convert running stage to be unresolved
-    fn rollback_running_stage(&mut self, stage_id: usize) -> Result<bool> {
+    /// Convert running stage to be unresolved,
+    /// Returns a Vec of RunningTaskInfo for running tasks in this stage.
+    pub fn rollback_running_stage(
+        &mut self,
+        stage_id: usize,
+        failure_reasons: HashSet<String>,
+    ) -> Result<Vec<RunningTaskInfo>> {
         if let Some(ExecutionStage::Running(stage)) = self.stages.remove(&stage_id) {
-            self.stages
-                .insert(stage_id, ExecutionStage::UnResolved(stage.to_unresolved()?));
-            Ok(true)
+            let running_tasks = stage
+                .running_tasks()
+                .into_iter()
+                .map(
+                    |(task_id, stage_id, partition_id, executor_id)| RunningTaskInfo {
+                        task_id,
+                        job_id: self.job_id.clone(),
+                        stage_id,
+                        partition_id,
+                        executor_id,
+                    },
+                )
+                .collect();
+            self.stages.insert(
+                stage_id,
+                ExecutionStage::UnResolved(stage.to_unresolved(failure_reasons)?),
+            );
+            Ok(running_tasks)
         } else {
             warn!(
                 "Fail to find a running stage {}/{} to rollback",
                 self.job_id(),
                 stage_id
             );
-            Ok(false)
+            Ok(vec![])
         }
     }
 
     /// Convert resolved stage to be unresolved
-    fn rollback_resolved_stage(&mut self, stage_id: usize) -> Result<bool> {
+    pub fn rollback_resolved_stage(&mut self, stage_id: usize) -> Result<bool> {
         if let Some(ExecutionStage::Resolved(stage)) = self.stages.remove(&stage_id) {
             self.stages
                 .insert(stage_id, ExecutionStage::UnResolved(stage.to_unresolved()?));
@@ -762,15 +1161,15 @@ impl ExecutionGraph {
         }
     }
 
-    /// Convert completed stage to be running
-    fn rerun_completed_stage(&mut self, stage_id: usize) -> bool {
-        if let Some(ExecutionStage::Completed(stage)) = self.stages.remove(&stage_id) {
+    /// Convert successful stage to be running
+    pub fn rerun_successful_stage(&mut self, stage_id: usize) -> bool {
+        if let Some(ExecutionStage::Successful(stage)) = self.stages.remove(&stage_id) {
             self.stages
                 .insert(stage_id, ExecutionStage::Running(stage.to_running()));
             true
         } else {
             warn!(
-                "Fail to find a completed stage {}/{} to rerun",
+                "Fail to find a successful stage {}/{} to rerun",
                 self.job_id(),
                 stage_id
             );
@@ -785,9 +1184,9 @@ impl ExecutionGraph {
         };
     }
 
-    /// finalize job as completed
-    fn complete_job(&mut self) -> Result<()> {
-        if !self.complete() {
+    /// Mark the job success
+    pub fn succeed_job(&mut self) -> Result<()> {
+        if !self.is_successful() {
             return Err(BallistaError::Internal(format!(
                 "Attempt to finalize an incomplete job {}",
                 self.job_id()
@@ -801,7 +1200,7 @@ impl ExecutionGraph {
             .collect::<Result<Vec<_>>>()?;
 
         self.status = JobStatus {
-            status: Some(job_status::Status::Completed(CompletedJob {
+            status: Some(job_status::Status::Successful(SuccessfulJob {
                 partition_location,
             })),
         };
@@ -809,6 +1208,11 @@ impl ExecutionGraph {
         Ok(())
     }
 
+    /// Clear the stage failure count for this stage if the stage is finally success
+    fn clear_stage_failure(&mut self, stage_id: usize) {
+        self.failed_stage_attempts.remove(&stage_id);
+    }
+
     pub(crate) async fn decode_execution_graph<
         T: 'static + AsLogicalPlan,
         U: 'static + AsExecutionPlan,
@@ -832,10 +1236,10 @@ impl ExecutionGraph {
                         ResolvedStage::decode(stage, codec, session_ctx)?;
                     (stage.stage_id, ExecutionStage::Resolved(stage))
                 }
-                StageType::CompletedStage(stage) => {
-                    let stage: CompletedStage =
-                        CompletedStage::decode(stage, codec, session_ctx)?;
-                    (stage.stage_id, ExecutionStage::Completed(stage))
+                StageType::SuccessfulStage(stage) => {
+                    let stage: SuccessfulStage =
+                        SuccessfulStage::decode(stage, codec, session_ctx)?;
+                    (stage.stage_id, ExecutionStage::Successful(stage))
                 }
                 StageType::FailedStage(stage) => {
                     let stage: FailedStage =
@@ -853,6 +1257,22 @@ impl ExecutionGraph {
             .map(|loc| loc.try_into())
             .collect::<Result<Vec<_>>>()?;
 
+        let failed_stage_attempts = proto
+            .failed_attempts
+            .into_iter()
+            .map(|attempt| {
+                (
+                    attempt.stage_id as usize,
+                    HashSet::from_iter(
+                        attempt
+                            .stage_attempt_num
+                            .into_iter()
+                            .map(|num| num as usize),
+                    ),
+                )
+            })
+            .collect();
+
         Ok(ExecutionGraph {
             scheduler_id: proto.scheduler_id,
             job_id: proto.job_id,
@@ -865,6 +1285,8 @@ impl ExecutionGraph {
             stages,
             output_partitions: proto.output_partitions as usize,
             output_locations,
+            task_id_gen: proto.task_id_gen as usize,
+            failed_stage_attempts,
         })
     }
 
@@ -893,8 +1315,8 @@ impl ExecutionGraph {
                     ExecutionStage::Running(stage) => StageType::ResolvedStage(
                         ResolvedStage::encode(stage.to_resolved(), codec)?,
                     ),
-                    ExecutionStage::Completed(stage) => StageType::CompletedStage(
-                        CompletedStage::encode(job_id.clone(), stage, codec)?,
+                    ExecutionStage::Successful(stage) => StageType::SuccessfulStage(
+                        SuccessfulStage::encode(job_id.clone(), stage, codec)?,
                     ),
                     ExecutionStage::Failed(stage) => StageType::FailedStage(
                         FailedStage::encode(job_id.clone(), stage, codec)?,
@@ -912,6 +1334,21 @@ impl ExecutionGraph {
             .map(|loc| loc.try_into())
             .collect::<Result<Vec<_>>>()?;
 
+        let failed_attempts: Vec<protobuf::StageAttempts> = graph
+            .failed_stage_attempts
+            .into_iter()
+            .map(|(stage_id, attempts)| {
+                let stage_attempt_num = attempts
+                    .into_iter()
+                    .map(|num| num as u32)
+                    .collect::<Vec<_>>();
+                protobuf::StageAttempts {
+                    stage_id: stage_id as u32,
+                    stage_attempt_num,
+                }
+            })
+            .collect::<Vec<_>>();
+
         Ok(protobuf::ExecutionGraph {
             job_id: graph.job_id,
             session_id: graph.session_id,
@@ -920,6 +1357,8 @@ impl ExecutionGraph {
             output_partitions: graph.output_partitions as u64,
             output_locations,
             scheduler_id: graph.scheduler_id,
+            task_id_gen: graph.task_id_gen as u32,
+            failed_attempts,
         })
     }
 }
@@ -932,8 +1371,8 @@ impl Debug for ExecutionGraph {
             .map(|(_, stage)| format!("{:?}", stage))
             .collect::<Vec<String>>()
             .join("");
-        write!(f, "ExecutionGraph[job_id={}, session_id={}, available_tasks={}, complete={}]\n{}",
-               self.job_id, self.session_id, self.available_tasks(), self.complete(), stages)
+        write!(f, "ExecutionGraph[job_id={}, session_id={}, available_tasks={}, is_successful={}]\n{}",
+               self.job_id, self.session_id, self.available_tasks(), self.is_successful(), stages)
     }
 }
 
@@ -984,10 +1423,12 @@ impl ExecutionStageBuilder {
             let stage = if child_stages.is_empty() {
                 ExecutionStage::Resolved(ResolvedStage::new(
                     stage_id,
+                    0,
                     stage,
                     partitioning,
                     output_links,
                     HashMap::new(),
+                    HashSet::new(),
                 ))
             } else {
                 ExecutionStage::UnResolved(UnresolvedStage::new(
@@ -1041,36 +1482,32 @@ impl ExecutionPlanVisitor for ExecutionStageBuilder {
     }
 }
 
-#[derive(Clone)]
-pub enum StageEvent {
-    StageResolved(usize),
-    StageCompleted(usize),
-    StageFailed(usize, String),
-    RollBackRunningStage(usize),
-    RollBackResolvedStage(usize),
-    ReRunCompletedStage(usize),
-}
-
 /// Represents the basic unit of work for the Ballista executor. Will execute
 /// one partition of one stage on one task slot.
 #[derive(Clone)]
-pub struct Task {
+pub struct TaskDescription {
     pub session_id: String,
     pub partition: PartitionId,
+    pub stage_attempt_num: usize,
+    pub task_id: usize,
+    pub task_attempt: usize,
     pub plan: Arc<dyn ExecutionPlan>,
     pub output_partitioning: Option<Partitioning>,
 }
 
-impl Debug for Task {
+impl Debug for TaskDescription {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         let plan = DisplayableExecutionPlan::new(self.plan.as_ref()).indent();
         write!(
             f,
-            "Task[session_id: {}, job: {}, stage: {}, partition: {}]\n{}",
+            "TaskDescription[session_id: {},job: {}, stage: {}.{}, partition: {} task_id {}, task attempt {}]\n{}",
             self.session_id,
             self.partition.job_id,
             self.partition.stage_id,
+            self.stage_attempt_num,
             self.partition.partition_id,
+            self.task_id,
+            self.task_attempt,
             plan
         )
     }
@@ -1078,6 +1515,7 @@ impl Debug for Task {
 
 fn partition_to_location(
     job_id: &str,
+    map_partition_id: usize,
     stage_id: usize,
     executor: &ExecutorMetadata,
     shuffles: Vec<ShuffleWritePartition>,
@@ -1085,6 +1523,7 @@ fn partition_to_location(
     shuffles
         .into_iter()
         .map(|shuffle| PartitionLocation {
+            map_partition_id,
             partition_id: PartitionId {
                 job_id: job_id.to_owned(),
                 stage_id,
@@ -1103,20 +1542,25 @@ fn partition_to_location(
 
 #[cfg(test)]
 mod test {
+    use std::collections::HashSet;
     use std::sync::Arc;
 
     use datafusion::arrow::datatypes::{DataType, Field, Schema};
-    use datafusion::logical_expr::{col, sum, Expr};
+    use datafusion::logical_expr::{col, count, sum, Expr};
     use datafusion::logical_plan::JoinType;
     use datafusion::physical_plan::display::DisplayableExecutionPlan;
     use datafusion::prelude::{SessionConfig, SessionContext};
     use datafusion::test_util::scan_empty;
 
+    use crate::scheduler_server::event::QueryStageSchedulerEvent;
     use ballista_core::error::Result;
-    use ballista_core::serde::protobuf::{self, job_status, task_status, TaskStatus};
+    use ballista_core::serde::protobuf::{
+        self, failed_task, job_status, task_status, ExecutionError, FailedTask,
+        FetchPartitionError, IoError, JobStatus, TaskKilled, TaskStatus,
+    };
     use ballista_core::serde::scheduler::{ExecutorMetadata, ExecutorSpecification};
 
-    use crate::state::execution_graph::{ExecutionGraph, Task};
+    use crate::state::execution_graph::{ExecutionGraph, TaskDescription};
 
     #[tokio::test]
     async fn test_drain_tasks() -> Result<()> {
@@ -1126,14 +1570,17 @@ mod test {
 
         drain_tasks(&mut agg_graph)?;
 
-        assert!(agg_graph.complete(), "Failed to complete aggregation plan");
+        assert!(
+            agg_graph.is_successful(),
+            "Failed to complete aggregation plan"
+        );
 
         let mut coalesce_graph = test_coalesce_plan(4).await;
 
         drain_tasks(&mut coalesce_graph)?;
 
         assert!(
-            coalesce_graph.complete(),
+            coalesce_graph.is_successful(),
             "Failed to complete coalesce plan"
         );
 
@@ -1143,7 +1590,7 @@ mod test {
 
         println!("{:?}", join_graph);
 
-        assert!(join_graph.complete(), "Failed to complete join plan");
+        assert!(join_graph.is_successful(), "Failed to complete join plan");
 
         let mut union_all_graph = test_union_all_plan(4).await;
 
@@ -1151,7 +1598,10 @@ mod test {
 
         println!("{:?}", union_all_graph);
 
-        assert!(union_all_graph.complete(), "Failed to complete union plan");
+        assert!(
+            union_all_graph.is_successful(),
+            "Failed to complete union plan"
+        );
 
         let mut union_graph = test_union_plan(4).await;
 
@@ -1159,7 +1609,7 @@ mod test {
 
         println!("{:?}", union_graph);
 
-        assert!(union_graph.complete(), "Failed to complete union plan");
+        assert!(union_graph.is_successful(), "Failed to complete union plan");
 
         Ok(())
     }
@@ -1175,7 +1625,7 @@ mod test {
         assert!(matches!(
             status,
             protobuf::JobStatus {
-                status: Some(job_status::Status::Completed(_))
+                status: Some(job_status::Status::Successful(_))
             }
         ));
 
@@ -1191,7 +1641,7 @@ mod test {
     }
 
     #[tokio::test]
-    async fn test_reset_completed_stage() -> Result<()> {
+    async fn test_reset_completed_stage_executor_lost() -> Result<()> {
         let executor1 = mock_executor("executor-id1".to_string());
         let executor2 = mock_executor("executor-id2".to_string());
         let mut join_graph = test_join_plan(4).await;
@@ -1208,13 +1658,13 @@ mod test {
         // Complete the first stage
         if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
             let task_status = mock_completed_task(task, &executor1.id);
-            join_graph.update_task_status(&executor1, vec![task_status])?;
+            join_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
         }
 
         // Complete the second stage
         if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
             let task_status = mock_completed_task(task, &executor2.id);
-            join_graph.update_task_status(&executor2, vec![task_status])?;
+            join_graph.update_task_status(&executor2, vec![task_status], 1, 1)?;
         }
 
         join_graph.revive();
@@ -1224,25 +1674,25 @@ mod test {
         // Complete 1 task
         if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
             let task_status = mock_completed_task(task, &executor1.id);
-            join_graph.update_task_status(&executor1, vec![task_status])?;
+            join_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
         }
         // Mock 1 running task
         let _task = join_graph.pop_next_task(&executor1.id)?;
 
-        let reset = join_graph.reset_stages(&executor1.id)?;
+        let reset = join_graph.reset_stages_on_lost_executor(&executor1.id)?;
 
         // Two stages were reset, 1 Running stage rollback to Unresolved and 1 Completed stage move to Running
-        assert_eq!(reset.len(), 2);
+        assert_eq!(reset.0.len(), 2);
         assert_eq!(join_graph.available_tasks(), 1);
 
         drain_tasks(&mut join_graph)?;
-        assert!(join_graph.complete(), "Failed to complete join plan");
+        assert!(join_graph.is_successful(), "Failed to complete join plan");
 
         Ok(())
     }
 
     #[tokio::test]
-    async fn test_reset_resolved_stage() -> Result<()> {
+    async fn test_reset_resolved_stage_executor_lost() -> Result<()> {
         let executor1 = mock_executor("executor-id1".to_string());
         let executor2 = mock_executor("executor-id2".to_string());
         let mut join_graph = test_join_plan(4).await;
@@ -1259,26 +1709,26 @@ mod test {
         // Complete the first stage
         if let Some(task) = join_graph.pop_next_task(&executor1.id)? {
             let task_status = mock_completed_task(task, &executor1.id);
-            join_graph.update_task_status(&executor1, vec![task_status])?;
+            join_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
         }
 
         // Complete the second stage
         if let Some(task) = join_graph.pop_next_task(&executor2.id)? {
             let task_status = mock_completed_task(task, &executor2.id);
-            join_graph.update_task_status(&executor2, vec![task_status])?;
+            join_graph.update_task_status(&executor2, vec![task_status], 1, 1)?;
         }
 
         // There are 0 tasks pending schedule now
         assert_eq!(join_graph.available_tasks(), 0);
 
-        let reset = join_graph.reset_stages(&executor1.id)?;
+        let reset = join_graph.reset_stages_on_lost_executor(&executor1.id)?;
 
         // Two stages were reset, 1 Resolved stage rollback to Unresolved and 1 Completed stage move to Running
-        assert_eq!(reset.len(), 2);
+        assert_eq!(reset.0.len(), 2);
         assert_eq!(join_graph.available_tasks(), 1);
 
         drain_tasks(&mut join_graph)?;
-        assert!(join_graph.complete(), "Failed to complete join plan");
+        assert!(join_graph.is_successful(), "Failed to complete join plan");
 
         Ok(())
     }
@@ -1301,19 +1751,19 @@ mod test {
         // Complete the first stage
         if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
             let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status])?;
+            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
         }
 
         // 1st task in the second stage
         if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
             let task_status = mock_completed_task(task, &executor2.id);
-            agg_graph.update_task_status(&executor2, vec![task_status])?;
+            agg_graph.update_task_status(&executor2, vec![task_status], 1, 1)?;
         }
 
         // 2rd task in the second stage
         if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
             let task_status = mock_completed_task(task, &executor1.id);
-            agg_graph.update_task_status(&executor1, vec![task_status])?;
+            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
         }
 
         // 3rd task in the second stage, scheduled but not completed
@@ -1322,73 +1772,1010 @@ mod test {
         // There is 1 task pending schedule now
         assert_eq!(agg_graph.available_tasks(), 1);
 
-        let reset = agg_graph.reset_stages(&executor1.id)?;
+        let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
 
         // 3rd task status update comes later.
         let task_status = mock_completed_task(task.unwrap(), &executor1.id);
-        agg_graph.update_task_status(&executor1, vec![task_status])?;
+        agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
 
         // Two stages were reset, 1 Running stage rollback to Unresolved and 1 Completed stage move to Running
-        assert_eq!(reset.len(), 2);
+        assert_eq!(reset.0.len(), 2);
         assert_eq!(agg_graph.available_tasks(), 1);
 
         // Call the reset again
-        let reset = agg_graph.reset_stages(&executor1.id)?;
-        assert_eq!(reset.len(), 0);
+        let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
+        assert_eq!(reset.0.len(), 0);
         assert_eq!(agg_graph.available_tasks(), 1);
 
         drain_tasks(&mut agg_graph)?;
-        assert!(agg_graph.complete(), "Failed to complete agg plan");
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
 
         Ok(())
     }
 
-    fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> {
-        let executor = mock_executor("executor-id1".to_string());
-        while let Some(task) = graph.pop_next_task(&executor.id)? {
-            let task_status = mock_completed_task(task, &executor.id);
-            graph.update_task_status(&executor, vec![task_status])?;
+    #[tokio::test]
+    async fn test_do_not_retry_killed_task() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(4).await;
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
+
+        // Complete the first stage
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
         }
 
-        Ok(())
-    }
+        // 1st task in the second stage
+        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status1 = mock_completed_task(task1, &executor2.id);
 
-    async fn test_aggregation_plan(partition: usize) -> ExecutionGraph {
-        let config = SessionConfig::new().with_target_partitions(partition);
-        let ctx = Arc::new(SessionContext::with_config(config));
+        // 2rd task in the second stage
+        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status2 = mock_failed_task(
+            task2,
+            FailedTask {
+                error: "Killed".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::TaskKilled(TaskKilled {})),
+            },
+        );
 
-        let schema = Schema::new(vec![
-            Field::new("id", DataType::Utf8, false),
-            Field::new("gmv", DataType::UInt64, false),
-        ]);
+        agg_graph.update_task_status(
+            &executor2,
+            vec![task_status1, task_status2],
+            4,
+            4,
+        )?;
 
-        let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
-            .unwrap()
-            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
-            .unwrap()
-            .build()
-            .unwrap();
+        // TODO the JobStatus is not 'Running' here, no place to mark it to 'Running' in current code base.
+        assert!(
+            matches!(
+                agg_graph.status,
+                JobStatus {
+                    status: Some(job_status::Status::Queued(_))
+                }
+            ),
+            "Expected job status to be running"
+        );
 
-        let optimized_plan = ctx.optimize(&logical_plan).unwrap();
+        assert_eq!(agg_graph.available_tasks(), 2);
+        drain_tasks(&mut agg_graph)?;
+        assert_eq!(agg_graph.available_tasks(), 0);
 
-        let plan = ctx.create_physical_plan(&optimized_plan).await.unwrap();
+        assert!(
+            !agg_graph.is_successful(),
+            "Expected the agg graph can not complete"
+        );
+        Ok(())
+    }
 
-        println!("{}", DisplayableExecutionPlan::new(plan.as_ref()).indent());
+    #[tokio::test]
+    async fn test_max_task_failed_count() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(2).await;
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
 
-        ExecutionGraph::new("localhost:50050", "job", "session", plan).unwrap()
-    }
+        // Complete the first stage
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
 
-    async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
-        let config = SessionConfig::new().with_target_partitions(partition);
-        let ctx = Arc::new(SessionContext::with_config(config));
+        // 1st task in the second stage
+        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status1 = mock_completed_task(task1, &executor2.id);
+
+        // 2rd task in the second stage, failed due to IOError
+        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status2 = mock_failed_task(
+            task2.clone(),
+            FailedTask {
+                error: "IOError".to_string(),
+                retryable: true,
+                count_to_failures: true,
+                failed_reason: Some(failed_task::FailedReason::IoError(IoError {})),
+            },
+        );
 
-        let schema = Schema::new(vec![
-            Field::new("id", DataType::Utf8, false),
-            Field::new("gmv", DataType::UInt64, false),
-        ]);
+        agg_graph.update_task_status(
+            &executor2,
+            vec![task_status1, task_status2],
+            4,
+            4,
+        )?;
 
-        let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
-            .unwrap()
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        let mut last_attempt = 0;
+        // 2rd task's attempts
+        for attempt in 1..5 {
+            if let Some(task2_attempt) = agg_graph.pop_next_task(&executor2.id)? {
+                assert_eq!(
+                    task2_attempt.partition.partition_id,
+                    task2.partition.partition_id
+                );
+                assert_eq!(task2_attempt.task_attempt, attempt);
+                last_attempt = task2_attempt.task_attempt;
+                let task_status = mock_failed_task(
+                    task2_attempt.clone(),
+                    FailedTask {
+                        error: "IOError".to_string(),
+                        retryable: true,
+                        count_to_failures: true,
+                        failed_reason: Some(failed_task::FailedReason::IoError(
+                            IoError {},
+                        )),
+                    },
+                );
+                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
+            }
+        }
+
+        assert!(
+            matches!(
+                agg_graph.status,
+                JobStatus {
+                    status: Some(job_status::Status::Failed(_))
+                }
+            ),
+            "Expected job status to be Failed"
+        );
+
+        assert_eq!(last_attempt, 3);
+
+        let failure_reason = format!("{:?}", agg_graph.status);
+        assert!(failure_reason.contains("Task 1 in Stage 2 failed 4 times, fail the stage, most recent failure reason"));
+        assert!(failure_reason.contains("IOError"));
+        assert!(!agg_graph.is_successful());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_long_delayed_failed_task_after_executor_lost() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(4).await;
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
+        }
+
+        // 1st task in the Stage 2
+        if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
+            let task_status = mock_completed_task(task, &executor2.id);
+            agg_graph.update_task_status(&executor2, vec![task_status], 1, 1)?;
+        }
+
+        // 2rd task in the Stage 2
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
+        }
+
+        // 3rd task in the Stage 2, scheduled on executor 2 but not completed
+        let task = agg_graph.pop_next_task(&executor2.id)?;
+
+        // There is 1 task pending schedule now
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        // executor 1 lost
+        let reset = agg_graph.reset_stages_on_lost_executor(&executor1.id)?;
+
+        // Two stages were reset, Stage 2 rollback to Unresolved and Stage 1 move to Running
+        assert_eq!(reset.0.len(), 2);
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        // Complete the Stage 1 again
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 1, 1)?;
+        }
+
+        // Stage 2 move to Running
+        agg_graph.revive();
+        assert_eq!(agg_graph.available_tasks(), 4);
+
+        // 3rd task in Stage 2 update comes very late due to runtime execution error.
+        let task_status = mock_failed_task(
+            task.unwrap(),
+            FailedTask {
+                error: "ExecutionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::ExecutionError(
+                    ExecutionError {},
+                )),
+            },
+        );
+
+        // This long delayed failed task should not failure the stage/job and should not trigger any query stage events
+        let query_stage_events =
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        assert!(query_stage_events.is_empty());
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
+
+        Ok(())
+    }
+
+    // A single FetchPartitionError from a Stage 2 task, pointing at Stage 1's output
+    // on executor 1, should emit exactly one CancelTasks event and leave Stage 1 as the
+    // only running stage with one pending task; draining the graph must then succeed.
+    #[tokio::test]
+    async fn test_normal_fetch_failure() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(4).await;
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
+
+        // 1st task in the Stage 2
+        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status1 = mock_completed_task(task1, &executor2.id);
+
+        // 2nd task in the Stage 2, failed due to FetchPartitionError
+        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status2 = mock_failed_task(
+            task2,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor1.id.clone(),
+                        map_stage_id: 1,
+                        map_partition_id: 0,
+                    },
+                )),
+            },
+        );
+
+        // Pop the remaining Stage 2 tasks so they are scheduled (but not yet
+        // reported) when the fetch failure is processed.
+        let mut running_task_count = 0;
+        while let Some(_task) = agg_graph.pop_next_task(&executor2.id)? {
+            running_task_count += 1;
+        }
+        assert_eq!(running_task_count, 2);
+
+        // Report the success and the fetch failure together in a single batch.
+        // NOTE(review): the trailing `4, 4` look like max task/stage failure limits — confirm.
+        let stage_events = agg_graph.update_task_status(
+            &executor2,
+            vec![task_status1, task_status2],
+            4,
+            4,
+        )?;
+
+        // The fetch failure yields exactly one CancelTasks event
+        assert_eq!(stage_events.len(), 1);
+        assert!(matches!(
+            stage_events[0],
+            QueryStageSchedulerEvent::CancelTasks(_)
+        ));
+
+        // Stage 1 is running
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 1);
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
+        Ok(())
+    }
+
+    // Six Stage 3 tasks (popped on executor 3) fail with FetchPartitionError against
+    // six different Stage 2 map partitions, all attributed to executor 2. After that
+    // single batch, Stage 2 must be the only running stage with 5 pending tasks
+    // (presumably the 5 Stage 2 tasks that ran on executor 2), and draining succeeds.
+    #[tokio::test]
+    async fn test_many_fetch_failures_in_one_stage() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let executor3 = mock_executor("executor-id3".to_string());
+        let mut agg_graph = test_two_aggregations_plan(8).await;
+
+        agg_graph.revive();
+        assert_eq!(agg_graph.stage_count(), 3);
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
+
+        // Complete the Stage 2, 5 tasks run on executor_2 and 3 tasks run on executor_1
+        for _i in 0..5 {
+            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
+                let task_status = mock_completed_task(task, &executor2.id);
+                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 3);
+        for _i in 0..3 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+        }
+
+        // Run Stage 3, 6 tasks failed due to FetchPartitionError on different map partitions on executor_2
+        // (map partitions 2..8; all statuses are collected and reported in one batch below)
+        let mut many_fetch_failure_status = vec![];
+        for part in 2..8 {
+            if let Some(task) = agg_graph.pop_next_task(&executor3.id)? {
+                let task_status = mock_failed_task(
+                    task,
+                    FailedTask {
+                        error: "FetchPartitionError".to_string(),
+                        retryable: false,
+                        count_to_failures: false,
+                        failed_reason: Some(
+                            failed_task::FailedReason::FetchPartitionError(
+                                FetchPartitionError {
+                                    executor_id: executor2.id.clone(),
+                                    map_stage_id: 2,
+                                    map_partition_id: part,
+                                },
+                            ),
+                        ),
+                    },
+                );
+                many_fetch_failure_status.push(task_status);
+            }
+        }
+        assert_eq!(many_fetch_failure_status.len(), 6);
+        agg_graph.update_task_status(&executor3, many_fetch_failure_status, 4, 4)?;
+
+        // The Running stage should be Stage 2 now
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 5);
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
+        Ok(())
+    }
+
+    // Repeatedly fail Stage 2 with a fetch failure against Stage 1's output. The
+    // first three attempts only roll back to Stage 1 (no stage events); on the fourth
+    // a JobRunningFailed event is emitted, and the final job status must report that
+    // Stage 2 failed 4 times with a FetchPartitionError.
+    #[tokio::test]
+    async fn test_many_consecutive_stage_fetch_failures() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(4).await;
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
+
+        // Each iteration: (re)complete Stage 1 on executor 1, then fail the first
+        // Stage 2 task on executor 2 with a fetch failure pointing at executor 1.
+        for attempt in 0..6 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+
+            // 1st task in the Stage 2, failed due to FetchPartitionError
+            if let Some(task1) = agg_graph.pop_next_task(&executor2.id)? {
+                let task_status1 = mock_failed_task(
+                    task1.clone(),
+                    FailedTask {
+                        error: "FetchPartitionError".to_string(),
+                        retryable: false,
+                        count_to_failures: false,
+                        failed_reason: Some(
+                            failed_task::FailedReason::FetchPartitionError(
+                                FetchPartitionError {
+                                    executor_id: executor1.id.clone(),
+                                    map_stage_id: 1,
+                                    map_partition_id: 0,
+                                },
+                            ),
+                        ),
+                    },
+                );
+
+                let stage_events =
+                    agg_graph.update_task_status(&executor2, vec![task_status1], 4, 4)?;
+
+                if attempt < 3 {
+                    // No JobRunningFailed stage events
+                    assert_eq!(stage_events.len(), 0);
+                    // Stage 1 is running
+                    let running_stage = agg_graph.running_stages();
+                    assert_eq!(running_stage.len(), 1);
+                    assert_eq!(running_stage[0], 1);
+                    assert_eq!(agg_graph.available_tasks(), 1);
+                } else {
+                    // Job is failed after exceeds the max_stage_failures
+                    assert_eq!(stage_events.len(), 1);
+                    assert!(matches!(
+                        stage_events[0],
+                        QueryStageSchedulerEvent::JobRunningFailed(_, _)
+                    ));
+                    // Stage 2 is still running
+                    let running_stage = agg_graph.running_stages();
+                    assert_eq!(running_stage.len(), 1);
+                    assert_eq!(running_stage[0], 2);
+                }
+            }
+        }
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(!agg_graph.is_successful(), "Expect to fail the agg plan");
+
+        let failure_reason = format!("{:?}", agg_graph.status);
+        assert!(failure_reason.contains("Job failed due to stage 2 failed: Stage 2 has failed 4 times, most recent failure reason"));
+        assert!(failure_reason.contains("FetchPartitionError"));
+
+        Ok(())
+    }
+
+    // Several delayed fetch failures from Stage 3 arrive while Stage 2 is being re-run.
+    // Only a failure naming a not-yet-handled executor should reset more Stage 2 tasks;
+    // duplicate reasons, and a failure arriving after Stage 2's new attempt finished,
+    // must be ignored. Only Stage 3 is recorded in `failed_stage_attempts`, and that
+    // record is cleared once the job succeeds.
+    #[tokio::test]
+    async fn test_long_delayed_fetch_failures() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let executor3 = mock_executor("executor-id3".to_string());
+        let mut agg_graph = test_two_aggregations_plan(8).await;
+
+        agg_graph.revive();
+        assert_eq!(agg_graph.stage_count(), 3);
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
+
+        // Complete the Stage 2, 5 tasks run on executor_2, 2 tasks run on executor_1, 1 task runs on executor_3
+        for _i in 0..5 {
+            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
+                let task_status = mock_completed_task(task, &executor2.id);
+                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 3);
+
+        for _i in 0..2 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+        }
+
+        if let Some(task) = agg_graph.pop_next_task(&executor3.id)? {
+            let task_status = mock_completed_task(task, &executor3.id);
+            agg_graph.update_task_status(&executor3, vec![task_status], 4, 4)?;
+        }
+        assert_eq!(agg_graph.available_tasks(), 0);
+
+        // Run Stage 3
+        // 1st task scheduled
+        let task_1 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+        // 2nd task scheduled
+        let task_2 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+        // 3rd task scheduled
+        let task_3 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+        // 4th task scheduled
+        let task_4 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+        // 5th task scheduled
+        let task_5 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+
+        // Stage 3, 1st task failed due to FetchPartitionError(executor2)
+        let task_status_1 = mock_failed_task(
+            task_1,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor2.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 0,
+                    },
+                )),
+            },
+        );
+        agg_graph.update_task_status(&executor3, vec![task_status_1], 4, 4)?;
+
+        // The Running stage is Stage 2 now
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 5);
+
+        // Stage 3, 2nd task failed due to FetchPartitionError(executor2)
+        let task_status_2 = mock_failed_task(
+            task_2,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor2.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 1,
+                    },
+                )),
+            },
+        );
+        // This task update should be ignored
+        agg_graph.update_task_status(&executor3, vec![task_status_2], 4, 4)?;
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 5);
+
+        // Stage 3, 3rd task failed due to FetchPartitionError(executor1)
+        let task_status_3 = mock_failed_task(
+            task_3,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor1.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 1,
+                    },
+                )),
+            },
+        );
+        // This task update should be handled because it has a different failure reason
+        agg_graph.update_task_status(&executor3, vec![task_status_3], 4, 4)?;
+        // Running stage is still Stage 2, but available tasks changed to 7.
+        // Re-read the running stages: the earlier snapshot predates this update.
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 7);
+
+        // Finish 4 tasks in Stage 2, to make some progress
+        for _i in 0..4 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+        }
+        // Stage 2 must still be the only running stage after partial progress
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 3);
+
+        // Stage 3, 4th task failed due to FetchPartitionError(executor1)
+        let task_status_4 = mock_failed_task(
+            task_4,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor1.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 1,
+                    },
+                )),
+            },
+        );
+        // This task update should be ignored because the same failure reason is already handled
+        agg_graph.update_task_status(&executor3, vec![task_status_4], 4, 4)?;
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 3);
+
+        // Finish the other 3 tasks in Stage 2
+        for _i in 0..3 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 0);
+
+        // Stage 3, the very long delayed 5th task failed due to FetchPartitionError(executor3)
+        // Although the failure reason is new, this task should be ignored
+        // because its map stage's new attempt is finished and this stage's new attempt is running
+        let task_status_5 = mock_failed_task(
+            task_5,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor3.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 1,
+                    },
+                )),
+            },
+        );
+        agg_graph.update_task_status(&executor3, vec![task_status_5], 4, 4)?;
+        // Stage 3's new attempt is running
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 3);
+        assert_eq!(agg_graph.available_tasks(), 8);
+
+        // There is one failed stage attempt: Stage 3. Stage 2 does not count to failed attempts
+        assert_eq!(agg_graph.failed_stage_attempts.len(), 1);
+        assert_eq!(
+            agg_graph.failed_stage_attempts.get(&3).cloned(),
+            Some(HashSet::from([0]))
+        );
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
+        // Failed stage attempts are cleaned
+        assert_eq!(agg_graph.failed_stage_attempts.len(), 0);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    // This test case covers a race condition in delayed fetch failure handling:
+    // TaskStatus of input stage's new attempt come together with the parent stage's delayed FetchFailure
+    async fn test_long_delayed_fetch_failures_race_condition() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let executor3 = mock_executor("executor-id3".to_string());
+        let mut agg_graph = test_two_aggregations_plan(8).await;
+
+        agg_graph.revive();
+        assert_eq!(agg_graph.stage_count(), 3);
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
+
+        // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on executor_1
+        for _i in 0..5 {
+            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
+                let task_status = mock_completed_task(task, &executor2.id);
+                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 3);
+
+        for _i in 0..3 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 0);
+
+        // Run Stage 3
+        // 1st task scheduled
+        let task_1 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+        // 2nd task scheduled
+        let task_2 = agg_graph.pop_next_task(&executor3.id)?.unwrap();
+
+        // Stage 3, 1st task failed due to FetchPartitionError(executor2)
+        let task_status_1 = mock_failed_task(
+            task_1,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor2.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 0,
+                    },
+                )),
+            },
+        );
+        agg_graph.update_task_status(&executor3, vec![task_status_1], 4, 4)?;
+
+        // The Running stage is Stage 2 now
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 5);
+
+        // Complete the 5 tasks in Stage 2's new attempts (statuses are held back, not reported yet)
+        let mut task_status_vec = vec![];
+        for _i in 0..5 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                task_status_vec.push(mock_completed_task(task, &executor1.id))
+            }
+        }
+
+        // Stage 3, 2nd task failed due to FetchPartitionError(executor1)
+        let task_status_2 = mock_failed_task(
+            task_2,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor1.id.clone(),
+                        map_stage_id: 2,
+                        map_partition_id: 1,
+                    },
+                )),
+            },
+        );
+        task_status_vec.push(task_status_2);
+
+        // TaskStatus of Stage 2 come together with Stage 3 delayed FetchFailure update.
+        // The successful tasks from Stage 2 would try to succeed the Stage2 and the delayed fetch failure try to reset the TaskInfo
+        agg_graph.update_task_status(&executor3, task_status_vec, 4, 4)?;
+        // The Running stage is still Stage 2, 3 new pending tasks added due to FetchPartitionError(executor1).
+        // Re-read the running stages: the earlier snapshot predates the combined update.
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 3);
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
+
+        Ok(())
+    }
+
+    // Fetch failures cascade across stages: a Stage 3 failure rolls back to Stage 2,
+    // then a failure in Stage 2's new attempt rolls back to Stage 1. Both failed
+    // attempts (Stage 2 attempt 1, Stage 3 attempt 0) are recorded in
+    // `failed_stage_attempts`, which must be empty again once the job succeeds.
+    #[tokio::test]
+    async fn test_fetch_failures_in_different_stages() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let executor3 = mock_executor("executor-id3".to_string());
+        let mut agg_graph = test_two_aggregations_plan(8).await;
+
+        agg_graph.revive();
+        assert_eq!(agg_graph.stage_count(), 3);
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
+
+        // Complete the Stage 2, 5 tasks run on executor_2, 3 tasks run on executor_1
+        for _i in 0..5 {
+            if let Some(task) = agg_graph.pop_next_task(&executor2.id)? {
+                let task_status = mock_completed_task(task, &executor2.id);
+                agg_graph.update_task_status(&executor2, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 3);
+        for _i in 0..3 {
+            if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+                let task_status = mock_completed_task(task, &executor1.id);
+                agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+            }
+        }
+        assert_eq!(agg_graph.available_tasks(), 0);
+
+        // Run Stage 3
+        // 1st task in the Stage 3, failed due to FetchPartitionError(executor1)
+        if let Some(task1) = agg_graph.pop_next_task(&executor3.id)? {
+            let task_status1 = mock_failed_task(
+                task1,
+                FailedTask {
+                    error: "FetchPartitionError".to_string(),
+                    retryable: false,
+                    count_to_failures: false,
+                    failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                        FetchPartitionError {
+                            executor_id: executor1.id.clone(),
+                            map_stage_id: 2,
+                            map_partition_id: 0,
+                        },
+                    )),
+                },
+            );
+
+            let _stage_events =
+                agg_graph.update_task_status(&executor3, vec![task_status1], 4, 4)?;
+        }
+        // The Running stage is Stage 2 now
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 2);
+        assert_eq!(agg_graph.available_tasks(), 3);
+
+        // 1st task in the Stage 2's new attempt, failed due to FetchPartitionError(executor1)
+        if let Some(task1) = agg_graph.pop_next_task(&executor3.id)? {
+            let task_status1 = mock_failed_task(
+                task1,
+                FailedTask {
+                    error: "FetchPartitionError".to_string(),
+                    retryable: false,
+                    count_to_failures: false,
+                    failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                        FetchPartitionError {
+                            executor_id: executor1.id.clone(),
+                            map_stage_id: 1,
+                            map_partition_id: 0,
+                        },
+                    )),
+                },
+            );
+            let _stage_events =
+                agg_graph.update_task_status(&executor3, vec![task_status1], 4, 4)?;
+        }
+        // The Running stage is Stage 1 now
+        let running_stage = agg_graph.running_stages();
+        assert_eq!(running_stage.len(), 1);
+        assert_eq!(running_stage[0], 1);
+        assert_eq!(agg_graph.available_tasks(), 1);
+
+        // There are two failed stage attempts: Stage 2 and Stage 3
+        assert_eq!(agg_graph.failed_stage_attempts.len(), 2);
+        assert_eq!(
+            agg_graph.failed_stage_attempts.get(&2).cloned(),
+            Some(HashSet::from([1]))
+        );
+        assert_eq!(
+            agg_graph.failed_stage_attempts.get(&3).cloned(),
+            Some(HashSet::from([0]))
+        );
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(agg_graph.is_successful(), "Failed to complete agg plan");
+        // Failed stage attempts are cleaned once the job succeeds
+        assert_eq!(agg_graph.failed_stage_attempts.len(), 0);
+        Ok(())
+    }
+
+    // A successful task, a FetchPartitionError and a non-retryable ExecutionError from
+    // Stage 2 are reported in one batch. Exactly one JobRunningFailed event must be
+    // emitted and the final job status must mention stage 2 and the ExecutionError.
+    #[tokio::test]
+    async fn test_fetch_failure_with_normal_task_failure() -> Result<()> {
+        let executor1 = mock_executor("executor-id1".to_string());
+        let executor2 = mock_executor("executor-id2".to_string());
+        let mut agg_graph = test_aggregation_plan(4).await;
+        // Call revive to move the leaf Resolved stages to Running
+        agg_graph.revive();
+
+        // Complete the Stage 1
+        if let Some(task) = agg_graph.pop_next_task(&executor1.id)? {
+            let task_status = mock_completed_task(task, &executor1.id);
+            agg_graph.update_task_status(&executor1, vec![task_status], 4, 4)?;
+        }
+
+        // 1st task in the Stage 2
+        let task1 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status1 = mock_completed_task(task1, &executor2.id);
+
+        // 2nd task in the Stage 2, failed due to FetchPartitionError
+        let task2 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status2 = mock_failed_task(
+            task2,
+            FailedTask {
+                error: "FetchPartitionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::FetchPartitionError(
+                    FetchPartitionError {
+                        executor_id: executor1.id.clone(),
+                        map_stage_id: 1,
+                        map_partition_id: 0,
+                    },
+                )),
+            },
+        );
+
+        // 3rd task in the Stage 2, failed due to ExecutionError
+        let task3 = agg_graph.pop_next_task(&executor2.id)?.unwrap();
+        let task_status3 = mock_failed_task(
+            task3,
+            FailedTask {
+                error: "ExecutionError".to_string(),
+                retryable: false,
+                count_to_failures: false,
+                failed_reason: Some(failed_task::FailedReason::ExecutionError(
+                    ExecutionError {},
+                )),
+            },
+        );
+
+        // All three statuses arrive in a single update
+        let stage_events = agg_graph.update_task_status(
+            &executor2,
+            vec![task_status1, task_status2, task_status3],
+            4,
+            4,
+        )?;
+
+        assert_eq!(stage_events.len(), 1);
+        assert!(matches!(
+            stage_events[0],
+            QueryStageSchedulerEvent::JobRunningFailed(_, _)
+        ));
+
+        drain_tasks(&mut agg_graph)?;
+        assert!(!agg_graph.is_successful(), "Expect to fail the agg plan");
+
+        let failure_reason = format!("{:?}", agg_graph.status);
+        assert!(failure_reason.contains("Job failed due to stage 2 failed"));
+        assert!(failure_reason.contains("ExecutionError"));
+
+        Ok(())
+    }
+
+    // #[tokio::test]
+    // async fn test_shuffle_files_should_cleaned_after_fetch_failure() -> Result<()> {
+    //     todo!()
+    // }
+
+    /// Runs the graph to exhaustion on a single mock executor ("executor-id1"):
+    /// every task handed out by `pop_next_task` is immediately reported back as
+    /// successfully completed, until no more tasks are available.
+    fn drain_tasks(graph: &mut ExecutionGraph) -> Result<()> {
+        let worker = mock_executor("executor-id1".to_string());
+        loop {
+            match graph.pop_next_task(&worker.id)? {
+                Some(task) => {
+                    let completed = mock_completed_task(task, &worker.id);
+                    graph.update_task_status(&worker, vec![completed], 1, 1)?;
+                }
+                None => return Ok(()),
+            }
+        }
+    }
+
+    /// Plans `SUM(gmv) GROUP BY id` over an empty `(id, gmv)` table using
+    /// `partition` target partitions, prints the physical plan, and wraps it in an
+    /// `ExecutionGraph` for job "job" / session "session".
+    async fn test_aggregation_plan(partition: usize) -> ExecutionGraph {
+        let ctx = Arc::new(SessionContext::with_config(
+            SessionConfig::new().with_target_partitions(partition),
+        ));
+
+        let fields = vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ];
+        let schema = Schema::new(fields);
+
+        let scan = scan_empty(None, &schema, Some(vec![0, 1])).unwrap();
+        let grouped = scan
+            .aggregate(vec![col("id")], vec![sum(col("gmv"))])
+            .unwrap();
+        let logical_plan = grouped.build().unwrap();
+
+        let optimized = ctx.optimize(&logical_plan).unwrap();
+        let physical_plan = ctx.create_physical_plan(&optimized).await.unwrap();
+
+        // Print the physical plan as a debugging aid for failing tests.
+        println!(
+            "{}",
+            DisplayableExecutionPlan::new(physical_plan.as_ref()).indent()
+        );
+
+        ExecutionGraph::new("localhost:50050", "job", "session", physical_plan).unwrap()
+    }
+
+    /// Plans a two-level aggregation over an empty `(id, name, gmv)` table —
+    /// `SUM(gmv)` grouped by `(id, name)`, then `COUNT(id)` grouped by `id` — using
+    /// `partition` target partitions, prints the physical plan, and wraps it in an
+    /// `ExecutionGraph` for job "job" / session "session".
+    async fn test_two_aggregations_plan(partition: usize) -> ExecutionGraph {
+        let ctx = Arc::new(SessionContext::with_config(
+            SessionConfig::new().with_target_partitions(partition),
+        ));
+
+        let fields = vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("name", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ];
+        let schema = Schema::new(fields);
+
+        let scan = scan_empty(None, &schema, Some(vec![0, 1, 2])).unwrap();
+        let per_id_name = scan
+            .aggregate(vec![col("id"), col("name")], vec![sum(col("gmv"))])
+            .unwrap();
+        let per_id = per_id_name
+            .aggregate(vec![col("id")], vec![count(col("id"))])
+            .unwrap();
+        let logical_plan = per_id.build().unwrap();
+
+        let optimized = ctx.optimize(&logical_plan).unwrap();
+        let physical_plan = ctx.create_physical_plan(&optimized).await.unwrap();
+
+        // Print the physical plan as a debugging aid for failing tests.
+        println!(
+            "{}",
+            DisplayableExecutionPlan::new(physical_plan.as_ref()).indent()
+        );
+
+        ExecutionGraph::new("localhost:50050", "job", "session", physical_plan).unwrap()
+    }
+
+    async fn test_coalesce_plan(partition: usize) -> ExecutionGraph {
+        let config = SessionConfig::new().with_target_partitions(partition);
+        let ctx = Arc::new(SessionContext::with_config(config));
+
+        let schema = Schema::new(vec![
+            Field::new("id", DataType::Utf8, false),
+            Field::new("gmv", DataType::UInt64, false),
+        ]);
+
+        let logical_plan = scan_empty(None, &schema, Some(vec![0, 1]))
+            .unwrap()
             .limit(0, Some(1))
             .unwrap()
             .build()
@@ -1507,7 +2894,7 @@ mod test {
         }
     }
 
-    fn mock_completed_task(task: Task, executor_id: &str) -> TaskStatus {
+    fn mock_completed_task(task: TaskDescription, executor_id: &str) -> TaskStatus {
         let mut partitions: Vec<protobuf::ShuffleWritePartition> = vec![];
 
         let num_partitions = task
@@ -1532,16 +2919,57 @@ mod test {
 
         // Complete the task
         protobuf::TaskStatus {
-            status: Some(task_status::Status::Completed(protobuf::CompletedTask {
+            task_id: task.task_id as u32,
+            job_id: task.partition.job_id.clone(),
+            stage_id: task.partition.stage_id as u32,
+            stage_attempt_num: task.stage_attempt_num as u32,
+            partition_id: task.partition.partition_id as u32,
+            launch_time: 0,
+            start_exec_time: 0,
+            end_exec_time: 0,
+            metrics: vec![],
+            status: Some(task_status::Status::Successful(protobuf::SuccessfulTask {
                 executor_id: executor_id.to_owned(),
                 partitions,
             })),
+        }
+    }
+
+    /// Builds a `TaskStatus` reporting that `task` failed with `failed_task`.
+    ///
+    /// Only the identifying fields are copied from `task`; the timing fields are
+    /// zeroed and no metrics are attached. A `Failed` status carries no shuffle
+    /// output, so (unlike `mock_completed_task`) no `ShuffleWritePartition`s are
+    /// built here.
+    fn mock_failed_task(task: TaskDescription, failed_task: FailedTask) -> TaskStatus {
+        // Fail the task
+        protobuf::TaskStatus {
+            task_id: task.task_id as u32,
+            job_id: task.partition.job_id.clone(),
+            stage_id: task.partition.stage_id as u32,
+            stage_attempt_num: task.stage_attempt_num as u32,
+            partition_id: task.partition.partition_id as u32,
+            launch_time: 0,
+            start_exec_time: 0,
+            end_exec_time: 0,
             metrics: vec![],
-            task_id: Some(protobuf::PartitionId {
-                job_id: task.partition.job_id.clone(),
-                stage_id: task.partition.stage_id as u32,
-                partition_id: task.partition.partition_id as u32,
-            }),
+            status: Some(task_status::Status::Failed(failed_task)),
         }
     }
 }
diff --git a/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs b/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
index 267b05a9..b7ef0087 100644
--- a/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph/execution_stage.rs
@@ -15,10 +15,12 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::collections::HashMap;
+use std::collections::{HashMap, HashSet};
 use std::convert::TryInto;
 use std::fmt::{Debug, Formatter};
+use std::iter::FromIterator;
 use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
 
 use datafusion::physical_plan::display::DisplayableExecutionPlan;
 use datafusion::physical_plan::metrics::{MetricValue, MetricsSet};
@@ -29,8 +31,10 @@ use log::{debug, warn};
 
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning;
+use ballista_core::serde::protobuf::failed_task::FailedReason;
 use ballista_core::serde::protobuf::{
-    self, CompletedTask, FailedTask, GraphStageInput, OperatorMetricsSet,
+    self, task_info, FailedTask, GraphStageInput, OperatorMetricsSet, ResultLost,
+    SuccessfulTask, TaskStatus,
 };
 use ballista_core::serde::protobuf::{task_status, RunningTask};
 use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
@@ -47,13 +51,13 @@ use crate::display::DisplayableBallistaExecutionPlan;
 ///       ↓            ↙           ↑
 ///  ResolvedStage     →     RunningStage
 ///                                ↓
-///                         CompletedStage
+///                         SuccessfulStage
 #[derive(Clone)]
 pub(crate) enum ExecutionStage {
+    /// Waiting for one or more of its input stages to complete.
     UnResolved(UnresolvedStage),
+    /// All inputs are available (or there are none); no per-task state is tracked yet.
     Resolved(ResolvedStage),
+    /// Tasks of the current attempt are being scheduled and tracked per partition.
     Running(RunningStage),
-    Completed(CompletedStage),
+    /// Every task succeeded; task infos and metrics are finalized.
+    Successful(SuccessfulStage),
+    /// The stage failed; an error message is kept alongside the task infos.
     Failed(FailedStage),
 }
 
@@ -63,7 +67,7 @@ impl Debug for ExecutionStage {
             ExecutionStage::UnResolved(unresolved_stage) => unresolved_stage.fmt(f),
             ExecutionStage::Resolved(resolved_stage) => resolved_stage.fmt(f),
             ExecutionStage::Running(running_stage) => running_stage.fmt(f),
-            ExecutionStage::Completed(completed_stage) => completed_stage.fmt(f),
+            ExecutionStage::Successful(successful_stage) => successful_stage.fmt(f),
             ExecutionStage::Failed(failed_stage) => failed_stage.fmt(f),
         }
     }
@@ -74,6 +78,8 @@ impl Debug for ExecutionStage {
 pub(crate) struct UnresolvedStage {
     /// Stage ID
     pub(crate) stage_id: usize,
+    /// Stage Attempt number
+    pub(crate) stage_attempt_num: usize,
     /// Output partitioning for this stage.
     pub(crate) output_partitioning: Option<Partitioning>,
     /// Stage ID of the stage that will take this stages outputs as inputs.
@@ -84,6 +90,8 @@ pub(crate) struct UnresolvedStage {
     pub(crate) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(crate) plan: Arc<dyn ExecutionPlan>,
+    /// Record last attempt's failure reasons to avoid duplicate resubmits
+    pub(crate) last_attempt_failure_reasons: HashSet<String>,
 }
 
 /// For a stage, if it has no inputs or all of its input stages are completed,
@@ -92,7 +100,9 @@ pub(crate) struct UnresolvedStage {
 pub(crate) struct ResolvedStage {
     /// Stage ID
     pub(crate) stage_id: usize,
-    /// Total number of output partitions for this stage.
+    /// Stage Attempt number
+    pub(crate) stage_attempt_num: usize,
+    /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
     /// Output partitioning for this stage.
@@ -104,6 +114,8 @@ pub(crate) struct ResolvedStage {
     pub(crate) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(crate) plan: Arc<dyn ExecutionPlan>,
+    /// Record last attempt's failure reasons to avoid duplicate resubmits
+    pub(crate) last_attempt_failure_reasons: HashSet<String>,
 }
 
 /// Different from the resolved stage, a running stage will
@@ -115,7 +127,9 @@ pub(crate) struct ResolvedStage {
 pub(crate) struct RunningStage {
     /// Stage ID
     pub(crate) stage_id: usize,
-    /// Total number of output partitions for this stage.
+    /// Stage Attempt number
+    pub(crate) stage_attempt_num: usize,
+    /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
     /// Output partitioning for this stage.
@@ -127,18 +141,24 @@ pub(crate) struct RunningStage {
     pub(crate) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(crate) plan: Arc<dyn ExecutionPlan>,
-    /// Status of each already scheduled task. If status is None, the partition has not yet been scheduled
-    pub(crate) task_statuses: Vec<Option<task_status::Status>>,
+    /// TaskInfo of each already scheduled task. If info is None, the partition has not yet been scheduled.
+    /// The index of the Vec is the task's partition id
+    pub(crate) task_infos: Vec<Option<TaskInfo>>,
+    /// Track the number of failures for each partition's task attempts.
+    /// The index of the Vec is the task's partition id.
+    pub(crate) task_failure_numbers: Vec<usize>,
     /// Combined metrics of the already finished tasks in the stage, If it is None, no task is finished yet.
     pub(crate) stage_metrics: Option<Vec<MetricsSet>>,
 }
 
 /// If a stage finishes successfully, its task statuses and metrics will be finalized
 #[derive(Clone)]
-pub(crate) struct CompletedStage {
+pub(crate) struct SuccessfulStage {
     /// Stage ID
     pub(crate) stage_id: usize,
-    /// Total number of output partitions for this stage.
+    /// Stage Attempt number
+    pub(crate) stage_attempt_num: usize,
+    /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
     /// Output partitioning for this stage.
@@ -150,8 +170,9 @@ pub(crate) struct CompletedStage {
     pub(crate) inputs: HashMap<usize, StageOutput>,
     /// `ExecutionPlan` for this stage
     pub(crate) plan: Arc<dyn ExecutionPlan>,
-    /// Status of each already scheduled task.
-    pub(crate) task_statuses: Vec<task_status::Status>,
+    /// TaskInfo of each already successful task.
+    /// The index of the Vec is the task's partition id
+    pub(crate) task_infos: Vec<TaskInfo>,
     /// Combined metrics of the already finished tasks in the stage.
     pub(crate) stage_metrics: Vec<MetricsSet>,
 }
@@ -161,7 +182,9 @@ pub(crate) struct CompletedStage {
 pub(crate) struct FailedStage {
     /// Stage ID
     pub(crate) stage_id: usize,
-    /// Total number of output partitions for this stage.
+    /// Stage Attempt number
+    pub(crate) stage_attempt_num: usize,
+    /// Total number of partitions for this stage.
     /// This stage will produce on task for partition.
     pub(crate) partitions: usize,
     /// Output partitioning for this stage.
@@ -171,14 +194,33 @@ pub(crate) struct FailedStage {
     pub(crate) output_links: Vec<usize>,
     /// `ExecutionPlan` for this stage
     pub(crate) plan: Arc<dyn ExecutionPlan>,
-    /// Status of each already scheduled task. If status is None, the partition has not yet been scheduled
-    pub(crate) task_statuses: Vec<Option<task_status::Status>>,
+    /// TaskInfo of each already scheduled tasks. If info is None, the partition has not yet been scheduled
+    /// The index of the Vec is the task's partition id
+    pub(crate) task_infos: Vec<Option<TaskInfo>>,
     /// Combined metrics of the already finished tasks in the stage, If it is None, no task is finished yet.
     pub(crate) stage_metrics: Option<Vec<MetricsSet>>,
     /// Error message
     pub(crate) error_message: String,
 }
 
+/// Bookkeeping for the most recently recorded attempt of the task that runs one
+/// partition of a stage: its task ID, timestamps and latest reported status.
+#[derive(Clone)]
+pub(crate) struct TaskInfo {
+    /// Task ID of this attempt. `RunningStage::update_task_info` ignores status
+    /// updates carrying a smaller task ID (i.e. from an older attempt).
+    pub(super) task_id: usize,
+    /// Task scheduled time; `update_task_info` carries it over unchanged.
+    /// NOTE(review): assumed to be epoch millis like `finish_time` — confirm where it is set.
+    pub(super) scheduled_time: u128,
+    /// Task launch time, copied from the executor-reported `TaskStatus::launch_time`.
+    pub(super) launch_time: u128,
+    /// Start execution time, copied from `TaskStatus::start_exec_time`.
+    pub(super) start_exec_time: u128,
+    /// Finish execution time, copied from `TaskStatus::end_exec_time`.
+    pub(super) end_exec_time: u128,
+    /// Task finish time: `update_task_info` sets it to the scheduler's current
+    /// time in milliseconds since the UNIX epoch when it applies an update.
+    pub(super) finish_time: u128,
+    /// Latest task status reported for this attempt.
+    pub(super) task_status: task_status::Status,
+}
+
 impl UnresolvedStage {
     pub(super) fn new(
         stage_id: usize,
@@ -194,26 +236,32 @@ impl UnresolvedStage {
 
         Self {
             stage_id,
+            stage_attempt_num: 0,
             output_partitioning,
             output_links,
             inputs,
             plan,
+            last_attempt_failure_reasons: Default::default(),
         }
     }
 
+    /// Builds an unresolved stage from explicit parts, keeping the given attempt
+    /// number and the failure reasons recorded for the last attempt. Used by the
+    /// `to_unresolved` rollbacks of resolved and running stages.
     pub(super) fn new_with_inputs(
         stage_id: usize,
+        stage_attempt_num: usize,
         plan: Arc<dyn ExecutionPlan>,
         output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
         inputs: HashMap<usize, StageOutput>,
+        last_attempt_failure_reasons: HashSet<String>,
     ) -> Self {
         Self {
             stage_id,
+            stage_attempt_num,
             output_partitioning,
             output_links,
             inputs,
             plan,
+            last_attempt_failure_reasons,
         }
     }
 
@@ -234,6 +282,35 @@ impl UnresolvedStage {
         Ok(())
     }
 
+    /// Remove from the output of input stage `input_stage_id` every partition
+    /// location hosted on executor `executor_id`.
+    ///
+    /// Returns the set of map partition ids whose output locations were removed.
+    /// The input is marked incomplete even if no location matched.
+    ///
+    /// `_input_partition_id` is currently unused: all partitions of the input
+    /// stage's output are scanned.
+    ///
+    /// # Errors
+    /// Returns `BallistaError::Internal` if `input_stage_id` is not an input of
+    /// this stage.
+    pub(super) fn remove_input_partitions(
+        &mut self,
+        input_stage_id: usize,
+        _input_partition_id: usize,
+        executor_id: &str,
+    ) -> Result<HashSet<usize>> {
+        let stage_id = self.stage_id;
+        let stage_output = self.inputs.get_mut(&input_stage_id).ok_or_else(|| {
+            BallistaError::Internal(format!(
+                "Error remove input partition for Stage {}, {} is not a valid child stage ID",
+                stage_id, input_stage_id
+            ))
+        })?;
+
+        let mut bad_map_partitions = HashSet::new();
+        for locs in stage_output.partition_locations.values_mut() {
+            // Single pass: drop locations on the executor while recording which
+            // map partitions produced them.
+            locs.retain(|loc| {
+                let on_executor = loc.executor_meta.id == executor_id;
+                if on_executor {
+                    bad_map_partitions.insert(loc.map_partition_id);
+                }
+                !on_executor
+            });
+        }
+        stage_output.complete = false;
+        Ok(bad_map_partitions)
+    }
+
     /// Marks the input stage ID as complete.
     pub(super) fn complete_input(&mut self, stage_id: usize) {
         if let Some(input) = self.inputs.get_mut(&stage_id) {
@@ -260,10 +337,12 @@ impl UnresolvedStage {
         )?;
         Ok(ResolvedStage::new(
             self.stage_id,
+            self.stage_attempt_num,
             plan,
             self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
+            self.last_attempt_failure_reasons.clone(),
         ))
     }
 
@@ -289,10 +368,14 @@ impl UnresolvedStage {
 
         Ok(UnresolvedStage {
             stage_id: stage.stage_id as usize,
+            stage_attempt_num: stage.stage_attempt_num as usize,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             plan,
             inputs,
+            last_attempt_failure_reasons: HashSet::from_iter(
+                stage.last_attempt_failure_reasons,
+            ),
         })
     }
 
@@ -310,11 +393,15 @@ impl UnresolvedStage {
             hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
 
         Ok(protobuf::UnResolvedStage {
-            stage_id: stage.stage_id as u64,
+            stage_id: stage.stage_id as u32,
+            stage_attempt_num: stage.stage_attempt_num as u32,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             inputs,
             plan,
+            last_attempt_failure_reasons: Vec::from_iter(
+                stage.last_attempt_failure_reasons,
+            ),
         })
     }
 }
@@ -325,8 +412,9 @@ impl Debug for UnresolvedStage {
 
         write!(
             f,
-            "=========UnResolvedStage[id={}, children={}]=========\nInputs{:?}\n{}",
+            "=========UnResolvedStage[stage_id={}.{}, children={}]=========\nInputs{:?}\n{}",
             self.stage_id,
+            self.stage_attempt_num,
             self.inputs.len(),
             self.inputs,
             plan
@@ -337,20 +425,24 @@ impl Debug for UnresolvedStage {
 impl ResolvedStage {
+    /// Builds a resolved stage. The task count (`partitions`) is derived from the
+    /// output partition count of `plan`: one task per output partition.
     pub(super) fn new(
         stage_id: usize,
+        stage_attempt_num: usize,
         plan: Arc<dyn ExecutionPlan>,
         output_partitioning: Option<Partitioning>,
         output_links: Vec<usize>,
         inputs: HashMap<usize, StageOutput>,
+        last_attempt_failure_reasons: HashSet<String>,
     ) -> Self {
         let partitions = plan.output_partitioning().partition_count();
 
         Self {
             stage_id,
+            stage_attempt_num,
             partitions,
             output_partitioning,
             output_links,
             inputs,
             plan,
+            last_attempt_failure_reasons,
         }
     }
 
@@ -358,6 +450,7 @@ impl ResolvedStage {
     pub(super) fn to_running(&self) -> RunningStage {
         RunningStage::new(
             self.stage_id,
+            self.stage_attempt_num,
             self.plan.clone(),
             self.partitions,
             self.output_partitioning.clone(),
@@ -372,10 +465,12 @@ impl ResolvedStage {
 
         let unresolved = UnresolvedStage::new_with_inputs(
             self.stage_id,
+            self.stage_attempt_num,
             new_plan,
             self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
+            self.last_attempt_failure_reasons.clone(),
         );
         Ok(unresolved)
     }
@@ -402,11 +497,15 @@ impl ResolvedStage {
 
         Ok(ResolvedStage {
             stage_id: stage.stage_id as usize,
+            stage_attempt_num: stage.stage_attempt_num as usize,
             partitions: stage.partitions as usize,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             inputs,
             plan,
+            last_attempt_failure_reasons: HashSet::from_iter(
+                stage.last_attempt_failure_reasons,
+            ),
         })
     }
 
@@ -424,12 +523,16 @@ impl ResolvedStage {
         let inputs = encode_inputs(stage.inputs)?;
 
         Ok(protobuf::ResolvedStage {
-            stage_id: stage.stage_id as u64,
+            stage_id: stage.stage_id as u32,
+            stage_attempt_num: stage.stage_attempt_num as u32,
             partitions: stage.partitions as u32,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             inputs,
             plan,
+            last_attempt_failure_reasons: Vec::from_iter(
+                stage.last_attempt_failure_reasons,
+            ),
         })
     }
 }
@@ -440,8 +543,8 @@ impl Debug for ResolvedStage {
 
         write!(
             f,
-            "=========ResolvedStage[id={}, partitions={}]=========\n{}",
-            self.stage_id, self.partitions, plan
+            "=========ResolvedStage[stage_id={}.{}, partitions={}]=========\n{}",
+            self.stage_id, self.stage_attempt_num, self.partitions, plan
         )
     }
 }
@@ -449,6 +552,7 @@ impl Debug for ResolvedStage {
 impl RunningStage {
     pub(super) fn new(
         stage_id: usize,
+        stage_attempt_num: usize,
         plan: Arc<dyn ExecutionPlan>,
         partitions: usize,
         output_partitioning: Option<Partitioning>,
@@ -457,26 +561,28 @@ impl RunningStage {
     ) -> Self {
         Self {
             stage_id,
+            stage_attempt_num,
             partitions,
             output_partitioning,
             output_links,
             inputs,
             plan,
-            task_statuses: vec![None; partitions],
+            task_infos: vec![None; partitions],
+            task_failure_numbers: vec![0; partitions],
             stage_metrics: None,
         }
     }
 
-    pub(super) fn to_completed(&self) -> CompletedStage {
-        let task_statuses = self
-            .task_statuses
+    pub(super) fn to_successful(&self) -> SuccessfulStage {
+        let task_infos = self
+            .task_infos
             .iter()
             .enumerate()
-            .map(|(task_id, status)| {
-                status.clone().unwrap_or_else(|| {
+            .map(|(partition_id, info)| {
+                info.clone().unwrap_or_else(|| {
                     panic!(
-                        "The status of task {}/{} should not be none",
-                        self.stage_id, task_id
+                        "TaskInfo for task {}.{}/{} should not be none",
+                        self.stage_id, self.stage_attempt_num, partition_id
                     )
                 })
             })
@@ -485,14 +591,15 @@ impl RunningStage {
             warn!("The metrics for stage {} should not be none", self.stage_id);
             vec![]
         });
-        CompletedStage {
+        SuccessfulStage {
             stage_id: self.stage_id,
+            stage_attempt_num: self.stage_attempt_num,
             partitions: self.partitions,
             output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
             inputs: self.inputs.clone(),
             plan: self.plan.clone(),
-            task_statuses,
+            task_infos,
             stage_metrics,
         }
     }
@@ -500,68 +607,92 @@ impl RunningStage {
+    /// Change to the failed state with `error_message`, keeping the current
+    /// attempt number, task infos and stage metrics. Stage inputs are not carried over.
     pub(super) fn to_failed(&self, error_message: String) -> FailedStage {
         FailedStage {
             stage_id: self.stage_id,
+            stage_attempt_num: self.stage_attempt_num,
             partitions: self.partitions,
             output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
             plan: self.plan.clone(),
-            task_statuses: self.task_statuses.clone(),
+            task_infos: self.task_infos.clone(),
             stage_metrics: self.stage_metrics.clone(),
             error_message,
         }
     }
 
+    /// Change to the resolved state and bump the stage attempt number
+    ///
+    /// Task infos, failure counts and metrics of the current attempt are dropped,
+    /// and the new attempt starts with no recorded failure reasons.
     pub(super) fn to_resolved(&self) -> ResolvedStage {
         ResolvedStage::new(
             self.stage_id,
+            self.stage_attempt_num + 1,
             self.plan.clone(),
             self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
+            HashSet::new(),
         )
     }
 
-    /// Change to the unresolved state
-    pub(super) fn to_unresolved(&self) -> Result<UnresolvedStage> {
+    /// Change to the unresolved state and bump the stage attempt number
+    ///
+    /// The plan is passed through `crate::planner::rollback_resolved_shuffles`, and
+    /// `failure_reasons` becomes the new attempt's `last_attempt_failure_reasons`.
+    ///
+    /// # Errors
+    /// Propagates any error returned by `rollback_resolved_shuffles`.
+    pub(super) fn to_unresolved(
+        &self,
+        failure_reasons: HashSet<String>,
+    ) -> Result<UnresolvedStage> {
         let new_plan = crate::planner::rollback_resolved_shuffles(self.plan.clone())?;
 
         let unresolved = UnresolvedStage::new_with_inputs(
             self.stage_id,
+            self.stage_attempt_num + 1,
             new_plan,
             self.output_partitioning.clone(),
             self.output_links.clone(),
             self.inputs.clone(),
+            failure_reasons,
         );
         Ok(unresolved)
     }
 
-    /// Returns `true` if all tasks for this stage are complete
-    pub(super) fn is_completed(&self) -> bool {
-        self.task_statuses
-            .iter()
-            .all(|status| matches!(status, Some(task_status::Status::Completed(_))))
+    /// Returns `true` if all tasks for this stage are successful
+    ///
+    /// Every partition must have a `TaskInfo` whose status is `Successful`; an
+    /// unscheduled (`None`) partition counts as not successful. Vacuously `true`
+    /// for a stage with no partitions.
+    pub(super) fn is_successful(&self) -> bool {
+        self.task_infos.iter().all(|info| {
+            matches!(
+                info,
+                Some(TaskInfo {
+                    task_status: task_status::Status::Successful(_),
+                    ..
+                })
+            )
+        })
     }
 
-    /// Returns the number of completed tasks
-    pub(super) fn completed_tasks(&self) -> usize {
-        self.task_statuses
+    /// Returns the number of successful tasks
+    /// (partitions whose current `TaskInfo` status is `Successful`).
+    pub(super) fn successful_tasks(&self) -> usize {
+        self.task_infos
             .iter()
-            .filter(|status| matches!(status, Some(task_status::Status::Completed(_))))
+            .filter(|info| {
+                matches!(
+                    info,
+                    Some(TaskInfo {
+                        task_status: task_status::Status::Successful(_),
+                        ..
+                    })
+                )
+            })
             .count()
     }
 
     /// Returns the number of scheduled tasks
+    /// (partitions that currently have a `TaskInfo`, whatever its status).
     pub(super) fn scheduled_tasks(&self) -> usize {
-        self.task_statuses.iter().filter(|s| s.is_some()).count()
+        self.task_infos.iter().filter(|s| s.is_some()).count()
     }
 
     /// Returns a vector of currently running tasks in this stage
-    pub(super) fn running_tasks(&self) -> Vec<(usize, usize, String)> {
-        self.task_statuses
+    pub(super) fn running_tasks(&self) -> Vec<(usize, usize, usize, String)> {
+        self.task_infos
             .iter()
             .enumerate()
-            .filter_map(|(partition, status)| match status {
-                Some(task_status::Status::Running(RunningTask { executor_id })) => {
-                    Some((self.stage_id, partition, executor_id.clone()))
+            .filter_map(|(partition, info)| match info {
+                Some(TaskInfo {task_id,
+                         task_status: task_status::Status::Running(RunningTask { executor_id }), ..}) => {
+                    Some((*task_id, self.stage_id, partition, executor_id.clone()))
                 }
                 _ => None,
             })
@@ -570,19 +701,50 @@ impl RunningStage {
 
     /// Returns the number of tasks in this stage which are available for scheduling.
     /// If the stage is not yet resolved, then this will return `0`, otherwise it will
-    /// return the number of tasks where the task status is not yet set.
+    /// return the number of tasks where the task info is not yet set.
+    ///
+    /// Concretely, this counts the partitions whose `TaskInfo` slot is `None`:
+    /// never scheduled, or cleared by `reset_task_info` / `reset_tasks`.
     pub(super) fn available_tasks(&self) -> usize {
-        self.task_statuses.iter().filter(|s| s.is_none()).count()
+        self.task_infos.iter().filter(|s| s.is_none()).count()
     }
 
-    /// Update the status for task partition
-    pub(super) fn update_task_status(
+    /// Update the TaskInfo for task partition `partition_id` from a `TaskStatus`
+    /// reported by an executor.
+    ///
+    /// Returns `false`, leaving the stage untouched, when the update is ignored.
+    /// That happens when:
+    /// - the partition has no `TaskInfo` (never scheduled, or already reset for
+    ///   rescheduling), e.g. a late or duplicate update;
+    /// - the update comes from an older attempt (smaller task ID);
+    /// - the update carries no status payload.
+    ///
+    /// Otherwise the `TaskInfo` is replaced and `true` is returned. A retryable
+    /// failure increments the partition's failure counter, a non-retryable one
+    /// leaves it unchanged, and any non-failed status resets it to zero.
+    ///
+    /// # Panics
+    /// Panics if `partition_id` is out of range for this stage.
+    pub(super) fn update_task_info(
         &mut self,
         partition_id: usize,
-        status: task_status::Status,
-    ) {
-        debug!("Updating task status for partition {}", partition_id);
-        self.task_statuses[partition_id] = Some(status);
+        status: TaskStatus,
+    ) -> bool {
+        debug!("Updating TaskInfo for partition {}", partition_id);
+        let task_info = match self.task_infos[partition_id].as_ref() {
+            Some(task_info) => task_info,
+            None => {
+                // The partition was reset (or never scheduled); applying this
+                // update would resurrect a task the scheduler no longer tracks.
+                warn!(
+                    "Ignore TaskStatus update with TID {} because no task is currently scheduled for partition {}",
+                    status.task_id, partition_id
+                );
+                return false;
+            }
+        };
+        let task_id = task_info.task_id;
+        if (status.task_id as usize) < task_id {
+            warn!("Ignore TaskStatus update with TID {} because there is more recent task attempt with TID {} running for partition {}",
+                status.task_id, task_id, partition_id);
+            return false;
+        }
+        let scheduled_time = task_info.scheduled_time;
+        let task_status = match status.status {
+            Some(task_status) => task_status,
+            None => {
+                warn!(
+                    "Ignore TaskStatus update with TID {} for partition {} because it carries no status",
+                    status.task_id, partition_id
+                );
+                return false;
+            }
+        };
+
+        // Only retryable failures count towards the retry limit; a non-retryable
+        // failure leaves the counter untouched and any other status resets it.
+        match &task_status {
+            task_status::Status::Failed(failed_task) => {
+                if failed_task.retryable {
+                    self.task_failure_numbers[partition_id] += 1;
+                }
+            }
+            _ => self.task_failure_numbers[partition_id] = 0,
+        }
+
+        self.task_infos[partition_id] = Some(TaskInfo {
+            task_id,
+            scheduled_time,
+            launch_time: status.launch_time as u128,
+            start_exec_time: status.start_exec_time as u128,
+            end_exec_time: status.end_exec_time as u128,
+            // A clock set before the epoch yields 0 instead of panicking the scheduler.
+            finish_time: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap_or_default()
+                .as_millis(),
+            task_status,
+        });
+        true
     }
 
     /// update and combine the task metrics to the stage metrics
@@ -639,22 +801,37 @@ impl RunningStage {
         first.aggregate_by_partition()
     }
 
+    /// Number of retryable failures recorded for partition `partition_id` by
+    /// `update_task_info` since that partition last reported a non-failed status
+    /// (counters start at zero for a newly created `RunningStage`).
+    ///
+    /// # Panics
+    /// Panics if `partition_id` is out of range for this stage.
+    pub(super) fn task_failure_number(&self, partition_id: usize) -> usize {
+        self.task_failure_numbers[partition_id]
+    }
+
+    /// Reset the task info for the given task partition. This should be called when a task failed and needs to be
+    /// re-scheduled: the partition becomes available for scheduling again, while its failure count in
+    /// `task_failure_numbers` is left unchanged.
+    ///
+    /// # Panics
+    /// Panics if `partition_id` is out of range for this stage.
+    pub fn reset_task_info(&mut self, partition_id: usize) {
+        self.task_infos[partition_id] = None;
+    }
+
     /// Reset the running and completed tasks on a given executor
     /// Returns the number of running tasks that were reset
     pub fn reset_tasks(&mut self, executor: &str) -> usize {
         let mut reset = 0;
-        for task in self.task_statuses.iter_mut() {
+        for task in self.task_infos.iter_mut() {
             match task {
-                Some(task_status::Status::Running(RunningTask { executor_id }))
-                    if *executor == *executor_id =>
-                {
+                Some(TaskInfo {
+                    task_status: task_status::Status::Running(RunningTask { executor_id }),
+                    ..
+                }) if *executor == *executor_id => {
                     *task = None;
                     reset += 1;
                 }
-                Some(task_status::Status::Completed(CompletedTask {
-                    executor_id,
-                    partitions: _,
-                })) if *executor == *executor_id => {
+                Some(TaskInfo {
+                    task_status:
+                        task_status::Status::Successful(SuccessfulTask {
+                            executor_id,
+                            partitions: _,
+                        }),
+                    ..
+                }) if *executor == *executor_id => {
                     *task = None;
                     reset += 1;
                 }
@@ -663,6 +840,35 @@ impl RunningStage {
         }
         reset
     }
+
+    /// Remove from the output of input stage `input_stage_id` every partition
+    /// location hosted on executor `executor_id`.
+    ///
+    /// Returns the set of map partition ids whose output locations were removed.
+    /// The input is marked incomplete even if no location matched.
+    ///
+    /// `_input_partition_id` is currently unused: all partitions of the input
+    /// stage's output are scanned.
+    ///
+    /// # Errors
+    /// Returns `BallistaError::Internal` if `input_stage_id` is not an input of
+    /// this stage.
+    pub(super) fn remove_input_partitions(
+        &mut self,
+        input_stage_id: usize,
+        _input_partition_id: usize,
+        executor_id: &str,
+    ) -> Result<HashSet<usize>> {
+        let stage_id = self.stage_id;
+        let stage_output = self.inputs.get_mut(&input_stage_id).ok_or_else(|| {
+            BallistaError::Internal(format!(
+                "Error remove input partition for Stage {}, {} is not a valid child stage ID",
+                stage_id, input_stage_id
+            ))
+        })?;
+
+        let mut bad_map_partitions = HashSet::new();
+        for locs in stage_output.partition_locations.values_mut() {
+            // Single pass: drop locations on the executor while recording which
+            // map partitions produced them.
+            locs.retain(|loc| {
+                let on_executor = loc.executor_meta.id == executor_id;
+                if on_executor {
+                    bad_map_partitions.insert(loc.map_partition_id);
+                }
+                !on_executor
+            });
+        }
+        stage_output.complete = false;
+        Ok(bad_map_partitions)
+    }
 }
 
 impl Debug for RunningStage {
@@ -671,10 +877,11 @@ impl Debug for RunningStage {
 
         write!(
             f,
-            "=========RunningStage[id={}, partitions={}, completed_tasks={}, scheduled_tasks={}, available_tasks={}]=========\n{}",
+            "=========RunningStage[stage_id={}.{}, partitions={}, successful_tasks={}, scheduled_tasks={}, available_tasks={}]=========\n{}",
             self.stage_id,
+            self.stage_attempt_num,
             self.partitions,
-            self.completed_tasks(),
+            self.successful_tasks(),
             self.scheduled_tasks(),
             self.available_tasks(),
             plan
@@ -682,41 +889,64 @@ impl Debug for RunningStage {
     }
 }
 
-impl CompletedStage {
+impl SuccessfulStage {
+    /// Change to the running state and bump the stage attempt number
     pub fn to_running(&self) -> RunningStage {
-        let mut task_status: Vec<Option<task_status::Status>> = Vec::new();
-        for task in self.task_statuses.iter() {
+        let mut task_infos: Vec<Option<TaskInfo>> = Vec::new();
+        for task in self.task_infos.iter() {
             match task {
-                task_status::Status::Completed(_) => task_status.push(Some(task.clone())),
-                _ => task_status.push(None),
+                TaskInfo {
+                    task_status: task_status::Status::Successful(_),
+                    ..
+                } => task_infos.push(Some(task.clone())),
+                _ => task_infos.push(None),
             }
         }
         RunningStage {
             stage_id: self.stage_id,
+            stage_attempt_num: self.stage_attempt_num + 1,
             partitions: self.partitions,
             output_partitioning: self.output_partitioning.clone(),
             output_links: self.output_links.clone(),
             inputs: self.inputs.clone(),
             plan: self.plan.clone(),
-            task_statuses: task_status,
+            task_infos,
+            // It is Ok to forget the previous task failure attempts
+            task_failure_numbers: vec![0; self.partitions],
             stage_metrics: Some(self.stage_metrics.clone()),
         }
     }
 
-    /// Reset the completed tasks on a given executor
+    /// Reset the successful tasks on a given executor
     /// Returns the number of running tasks that were reset
     pub fn reset_tasks(&mut self, executor: &str) -> usize {
         let mut reset = 0;
         let failure_reason = format!("Task failure due to Executor {} lost", executor);
-        for task in self.task_statuses.iter_mut() {
+        for task in self.task_infos.iter_mut() {
             match task {
-                task_status::Status::Completed(CompletedTask {
-                    executor_id,
-                    partitions: _,
-                }) if *executor == *executor_id => {
-                    *task = task_status::Status::Failed(FailedTask {
-                        error: failure_reason.clone(),
-                    });
+                TaskInfo {
+                    task_id,
+                    scheduled_time,
+                    task_status:
+                        task_status::Status::Successful(SuccessfulTask {
+                            executor_id, ..
+                        }),
+                    ..
+                } if *executor == *executor_id => {
+                    *task = TaskInfo {
+                        task_id: *task_id,
+                        scheduled_time: *scheduled_time,
+                        launch_time: 0,
+                        start_exec_time: 0,
+                        end_exec_time: 0,
+                        finish_time: 0,
+                        task_status: task_status::Status::Failed(FailedTask {
+                            error: failure_reason.clone(),
+                            retryable: true,
+                            count_to_failures: false,
+                            failed_reason: Some(FailedReason::ResultLost(ResultLost {})),
+                        }),
+                    };
                     reset += 1;
                 }
                 _ => {}
@@ -726,10 +956,10 @@ impl CompletedStage {
     }
 
     pub(super) fn decode<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
-        stage: protobuf::CompletedStage,
+        stage: protobuf::SuccessfulStage,
         codec: &BallistaCodec<T, U>,
         session_ctx: &SessionContext,
-    ) -> Result<CompletedStage> {
+    ) -> Result<SuccessfulStage> {
         let plan_proto = U::try_decode(&stage.plan)?;
         let plan = plan_proto.try_into_physical_plan(
             session_ctx,
@@ -744,41 +974,36 @@ impl CompletedStage {
         )?;
 
         let inputs = decode_inputs(stage.inputs)?;
-
-        let task_statuses = stage
-            .task_statuses
-            .into_iter()
-            .enumerate()
-            .map(|(task_id, status)| {
-                status.status.unwrap_or_else(|| {
-                    panic!("Status for task {} should not be none", task_id)
-                })
-            })
-            .collect();
-
+        assert_eq!(
+            stage.task_infos.len(),
+            stage.partitions as usize,
+            "protobuf::SuccessfulStage task_infos len not equal to partitions."
+        );
+        let task_infos = stage.task_infos.into_iter().map(decode_taskinfo).collect();
         let stage_metrics = stage
             .stage_metrics
             .into_iter()
             .map(|m| m.try_into())
             .collect::<Result<Vec<_>>>()?;
 
-        Ok(CompletedStage {
+        Ok(SuccessfulStage {
             stage_id: stage.stage_id as usize,
+            stage_attempt_num: stage.stage_attempt_num as usize,
             partitions: stage.partitions as usize,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             inputs,
             plan,
-            task_statuses,
+            task_infos,
             stage_metrics,
         })
     }
 
     pub(super) fn encode<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
-        job_id: String,
-        stage: CompletedStage,
+        _job_id: String,
+        stage: SuccessfulStage,
         codec: &BallistaCodec<T, U>,
-    ) -> Result<protobuf::CompletedStage> {
+    ) -> Result<protobuf::SuccessfulStage> {
         let stage_id = stage.stage_id;
 
         let mut plan: Vec<u8> = vec![];
@@ -789,23 +1014,11 @@ impl CompletedStage {
             hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
 
         let inputs = encode_inputs(stage.inputs)?;
-
-        let task_statuses: Vec<protobuf::TaskStatus> = stage
-            .task_statuses
+        let task_infos = stage
+            .task_infos
             .into_iter()
             .enumerate()
-            .map(|(partition, status)| {
-                protobuf::TaskStatus {
-                    task_id: Some(protobuf::PartitionId {
-                        job_id: job_id.clone(),
-                        stage_id: stage_id as u32,
-                        partition_id: partition as u32,
-                    }),
-                    // task metrics should not persist.
-                    metrics: vec![],
-                    status: Some(status),
-                }
-            })
+            .map(|(partition, task_info)| encode_taskinfo(task_info, partition))
             .collect();
 
         let stage_metrics = stage
@@ -814,20 +1027,21 @@ impl CompletedStage {
             .map(|m| m.try_into())
             .collect::<Result<Vec<_>>>()?;
 
-        Ok(protobuf::CompletedStage {
-            stage_id: stage_id as u64,
+        Ok(protobuf::SuccessfulStage {
+            stage_id: stage_id as u32,
+            stage_attempt_num: stage.stage_attempt_num as u32,
             partitions: stage.partitions as u32,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             inputs,
             plan,
-            task_statuses,
+            task_infos,
             stage_metrics,
         })
     }
 }
 
-impl Debug for CompletedStage {
+impl Debug for SuccessfulStage {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         let plan = DisplayableBallistaExecutionPlan::new(
             self.plan.as_ref(),
@@ -837,31 +1051,38 @@ impl Debug for CompletedStage {
 
         write!(
             f,
-            "=========CompletedStage[id={}, partitions={}]=========\n{}",
-            self.stage_id, self.partitions, plan
+            "=========SuccessfulStage[stage_id={}.{}, partitions={}]=========\n{}",
+            self.stage_id, self.stage_attempt_num, self.partitions, plan
         )
     }
 }
 
 impl FailedStage {
-    /// Returns the number of completed tasks
-    pub(super) fn completed_tasks(&self) -> usize {
-        self.task_statuses
+    /// Returns the number of successful tasks
+    pub(super) fn successful_tasks(&self) -> usize {
+        self.task_infos
             .iter()
-            .filter(|status| matches!(status, Some(task_status::Status::Completed(_))))
+            .filter(|info| {
+                matches!(
+                    info,
+                    Some(TaskInfo {
+                        task_status: task_status::Status::Successful(_),
+                        ..
+                    })
+                )
+            })
             .count()
     }
-
     /// Returns the number of scheduled tasks
     pub(super) fn scheduled_tasks(&self) -> usize {
-        self.task_statuses.iter().filter(|s| s.is_some()).count()
+        self.task_infos.iter().filter(|s| s.is_some()).count()
     }
 
     /// Returns the number of tasks in this stage which are available for scheduling.
     /// If the stage is not yet resolved, then this will return `0`, otherwise it will
     /// return the number of tasks where the task status is not yet set.
     pub(super) fn available_tasks(&self) -> usize {
-        self.task_statuses.iter().filter(|s| s.is_none()).count()
+        self.task_infos.iter().filter(|s| s.is_none()).count()
     }
 
     pub(super) fn decode<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
@@ -882,12 +1103,9 @@ impl FailedStage {
             plan.schema().as_ref(),
         )?;
 
-        let mut task_statuses: Vec<Option<task_status::Status>> =
-            vec![None; stage.partitions as usize];
-        for status in stage.task_statuses {
-            if let Some(task_id) = status.task_id.as_ref() {
-                task_statuses[task_id.partition_id as usize] = status.status
-            }
+        let mut task_infos: Vec<Option<TaskInfo>> = vec![None; stage.partitions as usize];
+        for info in stage.task_infos {
+            task_infos[info.partition_id as usize] = Some(decode_taskinfo(info.clone()));
         }
 
         let stage_metrics = if stage.stage_metrics.is_empty() {
@@ -903,18 +1121,19 @@ impl FailedStage {
 
         Ok(FailedStage {
             stage_id: stage.stage_id as usize,
+            stage_attempt_num: stage.stage_attempt_num as usize,
             partitions: stage.partitions as usize,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as usize).collect(),
             plan,
-            task_statuses,
+            task_infos,
             stage_metrics,
             error_message: stage.error_message,
         })
     }
 
     pub(super) fn encode<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
-        job_id: String,
+        _job_id: String,
         stage: FailedStage,
         codec: &BallistaCodec<T, U>,
     ) -> Result<protobuf::FailedStage> {
@@ -927,21 +1146,12 @@ impl FailedStage {
         let output_partitioning =
             hash_partitioning_to_proto(stage.output_partitioning.as_ref())?;
 
-        let task_statuses: Vec<protobuf::TaskStatus> = stage
-            .task_statuses
+        let task_infos: Vec<protobuf::TaskInfo> = stage
+            .task_infos
             .into_iter()
             .enumerate()
-            .filter_map(|(partition, status)| {
-                status.map(|status| protobuf::TaskStatus {
-                    task_id: Some(protobuf::PartitionId {
-                        job_id: job_id.clone(),
-                        stage_id: stage_id as u32,
-                        partition_id: partition as u32,
-                    }),
-                    // task metrics should not persist.
-                    metrics: vec![],
-                    status: Some(status),
-                })
+            .filter_map(|(partition, task_info)| {
+                task_info.map(|info| encode_taskinfo(info, partition))
             })
             .collect();
 
@@ -953,12 +1163,13 @@ impl FailedStage {
             .collect::<Result<Vec<_>>>()?;
 
         Ok(protobuf::FailedStage {
-            stage_id: stage_id as u64,
+            stage_id: stage_id as u32,
+            stage_attempt_num: stage.stage_attempt_num as u32,
             partitions: stage.partitions as u32,
             output_partitioning,
             output_links: stage.output_links.into_iter().map(|l| l as u32).collect(),
             plan,
-            task_statuses,
+            task_infos,
             stage_metrics,
             error_message: stage.error_message,
         })
@@ -971,10 +1182,11 @@ impl Debug for FailedStage {
 
         write!(
             f,
-            "=========FailedStage[id={}, partitions={}, completed_tasks={}, scheduled_tasks={}, available_tasks={}, error_message={}]=========\n{}",
+            "=========FailedStage[stage_id={}.{}, partitions={}, successful_tasks={}, scheduled_tasks={}, available_tasks={}, error_message={}]=========\n{}",
             self.stage_id,
+            self.stage_attempt_num,
             self.partitions,
-            self.completed_tasks(),
+            self.successful_tasks(),
             self.scheduled_tasks(),
             self.available_tasks(),
             self.error_message,
@@ -1080,3 +1292,48 @@ fn encode_inputs(
     }
     Ok(inputs)
 }
+
+fn decode_taskinfo(task_info: protobuf::TaskInfo) -> TaskInfo {
+    let task_info_status = match task_info.status {
+        Some(task_info::Status::Running(running)) => {
+            task_status::Status::Running(running)
+        }
+        Some(task_info::Status::Failed(failed)) => task_status::Status::Failed(failed),
+        Some(task_info::Status::Successful(success)) => {
+            task_status::Status::Successful(success)
+        }
+        _ => panic!(
+            "protobuf::TaskInfo status for task {} should not be none",
+            task_info.task_id
+        ),
+    };
+    TaskInfo {
+        task_id: task_info.task_id as usize,
+        scheduled_time: task_info.scheduled_time as u128,
+        launch_time: task_info.launch_time as u128,
+        start_exec_time: task_info.start_exec_time as u128,
+        end_exec_time: task_info.end_exec_time as u128,
+        finish_time: task_info.finish_time as u128,
+        task_status: task_info_status,
+    }
+}
+
+fn encode_taskinfo(task_info: TaskInfo, partition_id: usize) -> protobuf::TaskInfo {
+    let task_info_status = match task_info.task_status {
+        task_status::Status::Running(running) => task_info::Status::Running(running),
+        task_status::Status::Failed(failed) => task_info::Status::Failed(failed),
+        task_status::Status::Successful(success) => {
+            task_info::Status::Successful(success)
+        }
+    };
+    protobuf::TaskInfo {
+        task_id: task_info.task_id as u32,
+        partition_id: partition_id as u32,
+        scheduled_time: task_info.scheduled_time as u64,
+        launch_time: task_info.launch_time as u64,
+        start_exec_time: task_info.start_exec_time as u64,
+        end_exec_time: task_info.end_exec_time as u64,
+        finish_time: task_info.finish_time as u64,
+        status: Some(task_info_status),
+    }
+}
diff --git a/ballista/rust/scheduler/src/state/execution_graph_dot.rs b/ballista/rust/scheduler/src/state/execution_graph_dot.rs
index c09789e3..7962b21a 100644
--- a/ballista/rust/scheduler/src/state/execution_graph_dot.rs
+++ b/ballista/rust/scheduler/src/state/execution_graph_dot.rs
@@ -102,7 +102,7 @@ impl ExecutionGraphDot {
                         0,
                     )?);
                 }
-                ExecutionStage::Completed(stage) => {
+                ExecutionStage::Successful(stage) => {
                     writeln!(&mut dot, "\t\tlabel = \"Stage {} [Completed]\";", id)?;
                     stage_meta.push(write_stage_plan(
                         &mut dot,
diff --git a/ballista/rust/scheduler/src/state/executor_manager.rs b/ballista/rust/scheduler/src/state/executor_manager.rs
index 9fc8df90..1d135ef8 100644
--- a/ballista/rust/scheduler/src/state/executor_manager.rs
+++ b/ballista/rust/scheduler/src/state/executor_manager.rs
@@ -23,14 +23,15 @@ use crate::state::{decode_into, decode_protobuf, encode_protobuf, with_lock};
 use ballista_core::error::{BallistaError, Result};
 use ballista_core::serde::protobuf;
 
+use crate::state::execution_graph::RunningTaskInfo;
 use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
 use ballista_core::serde::protobuf::{
-    executor_status, ExecutorHeartbeat, ExecutorStatus,
+    executor_status, CancelTasksParams, ExecutorHeartbeat, ExecutorStatus,
 };
 use ballista_core::serde::scheduler::{ExecutorData, ExecutorMetadata};
 use ballista_core::utils::create_grpc_client_connection;
 use futures::StreamExt;
-use log::{debug, info};
+use log::{debug, error, info};
 use parking_lot::RwLock;
 use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
@@ -216,6 +217,47 @@ impl ExecutorManager {
         .await
     }
 
+    /// Send rpc to Executors to cancel the running tasks
+    pub async fn cancel_running_tasks(&self, tasks: Vec<RunningTaskInfo>) -> Result<()> {
+        let mut tasks_to_cancel: HashMap<&str, Vec<protobuf::RunningTaskInfo>> =
+            Default::default();
+
+        for task_info in &tasks {
+            if let Some(infos) = tasks_to_cancel.get_mut(task_info.executor_id.as_str()) {
+                infos.push(protobuf::RunningTaskInfo {
+                    task_id: task_info.task_id as u32,
+                    job_id: task_info.job_id.clone(),
+                    stage_id: task_info.stage_id as u32,
+                    partition_id: task_info.partition_id as u32,
+                })
+            } else {
+                tasks_to_cancel.insert(
+                    task_info.executor_id.as_str(),
+                    vec![protobuf::RunningTaskInfo {
+                        task_id: task_info.task_id as u32,
+                        job_id: task_info.job_id.clone(),
+                        stage_id: task_info.stage_id as u32,
+                        partition_id: task_info.partition_id as u32,
+                    }],
+                );
+            }
+        }
+
+        for (executor_id, infos) in tasks_to_cancel {
+            if let Ok(mut client) = self.get_client(executor_id).await {
+                client
+                    .cancel_tasks(CancelTasksParams { task_infos: infos })
+                    .await?;
+            } else {
+                error!(
+                    "Failed to get client for executor ID {} to cancel tasks",
+                    executor_id
+                )
+            }
+        }
+        Ok(())
+    }
+
     pub async fn get_client(
         &self,
         executor_id: &str,
diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs
index f5fe410e..1ced943e 100644
--- a/ballista/rust/scheduler/src/state/mod.rs
+++ b/ballista/rust/scheduler/src/state/mod.rs
@@ -288,8 +288,7 @@ mod test {
     use ballista_core::config::{BallistaConfig, BALLISTA_DEFAULT_SHUFFLE_PARTITIONS};
     use ballista_core::error::Result;
     use ballista_core::serde::protobuf::{
-        task_status, CompletedTask, PartitionId, PhysicalPlanNode, ShuffleWritePartition,
-        TaskStatus,
+        task_status, PhysicalPlanNode, ShuffleWritePartition, SuccessfulTask, TaskStatus,
     };
     use ballista_core::serde::scheduler::{
         ExecutorData, ExecutorMetadata, ExecutorSpecification,
@@ -423,36 +422,49 @@ mod test {
         let (executor_metadata, executor_data) = executors[0].clone();
 
         // Complete the first stage. So we should now have 4 pending tasks for this job stage 2
-        let mut partitions: Vec<ShuffleWritePartition> = vec![];
-
-        for partition_id in 0..4 {
-            partitions.push(ShuffleWritePartition {
-                partition_id: partition_id as u64,
-                path: "some/path".to_string(),
-                num_batches: 1,
-                num_rows: 1,
-                num_bytes: 1,
-            })
-        }
-
-        state
-            .task_manager
-            .update_task_statuses(
-                &executor_metadata,
-                vec![TaskStatus {
-                    task_id: Some(PartitionId {
+        {
+            let plan_graph = state
+                .task_manager
+                .get_active_execution_graph("job-1")
+                .await
+                .unwrap();
+            let task_def = plan_graph
+                .write()
+                .await
+                .pop_next_task(&executor_data.executor_id)?
+                .unwrap();
+            let mut partitions: Vec<ShuffleWritePartition> = vec![];
+            for partition_id in 0..4 {
+                partitions.push(ShuffleWritePartition {
+                    partition_id: partition_id as u64,
+                    path: "some/path".to_string(),
+                    num_batches: 1,
+                    num_rows: 1,
+                    num_bytes: 1,
+                })
+            }
+            state
+                .task_manager
+                .update_task_statuses(
+                    &executor_metadata,
+                    vec![TaskStatus {
+                        task_id: task_def.task_id as u32,
                         job_id: "job-1".to_string(),
-                        stage_id: 1,
-                        partition_id: 0,
-                    }),
-                    metrics: vec![],
-                    status: Some(task_status::Status::Completed(CompletedTask {
-                        executor_id: "executor-1".to_string(),
-                        partitions,
-                    })),
-                }],
-            )
-            .await?;
+                        stage_id: task_def.partition.stage_id as u32,
+                        stage_attempt_num: task_def.stage_attempt_num as u32,
+                        partition_id: task_def.partition.partition_id as u32,
+                        launch_time: 0,
+                        start_exec_time: 0,
+                        end_exec_time: 0,
+                        metrics: vec![],
+                        status: Some(task_status::Status::Successful(SuccessfulTask {
+                            executor_id: executor_data.executor_id.clone(),
+                            partitions,
+                        })),
+                    }],
+                )
+                .await?;
+        }
 
         state
             .executor_manager
diff --git a/ballista/rust/scheduler/src/state/task_manager.rs b/ballista/rust/scheduler/src/state/task_manager.rs
index 9f9bcb5a..9a95775d 100644
--- a/ballista/rust/scheduler/src/state/task_manager.rs
+++ b/ballista/rust/scheduler/src/state/task_manager.rs
@@ -18,18 +18,19 @@
 use crate::scheduler_server::event::QueryStageSchedulerEvent;
 use crate::scheduler_server::SessionBuilder;
 use crate::state::backend::{Keyspace, Lock, StateBackendClient};
-use crate::state::execution_graph::{ExecutionGraph, ExecutionStage, Task};
+use crate::state::execution_graph::{
+    ExecutionGraph, ExecutionStage, RunningTaskInfo, TaskDescription,
+};
 use crate::state::executor_manager::{ExecutorManager, ExecutorReservation};
 use crate::state::{decode_protobuf, encode_protobuf, with_lock};
 use ballista_core::config::BallistaConfig;
 #[cfg(not(test))]
 use ballista_core::error::BallistaError;
 use ballista_core::error::Result;
-use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient;
 
 use crate::state::session_manager::create_datafusion_context;
 use ballista_core::serde::protobuf::{
-    self, job_status, CancelTasksParams, FailedJob, JobStatus, TaskDefinition, TaskStatus,
+    self, job_status, FailedJob, JobStatus, TaskDefinition, TaskStatus,
 };
 use ballista_core::serde::scheduler::to_proto::hash_partitioning_to_proto;
 use ballista_core::serde::scheduler::ExecutorMetadata;
@@ -40,20 +41,22 @@ use datafusion_proto::logical_plan::AsLogicalPlan;
 use log::{debug, error, info, warn};
 use rand::distributions::Alphanumeric;
 use rand::{thread_rng, Rng};
-use std::collections::HashMap;
-use std::default::Default;
+use std::collections::{HashMap, HashSet};
 use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
 use tokio::sync::RwLock;
-use tonic::transport::Channel;
 
-type ExecutorClients = Arc<RwLock<HashMap<String, ExecutorGrpcClient<Channel>>>>;
 type ExecutionGraphCache = Arc<RwLock<HashMap<String, Arc<RwLock<ExecutionGraph>>>>>;
 
+// TODO move to configuration file
+/// Default max failure attempts for task level retry
+pub const TASK_MAX_FAILURES: usize = 4;
+/// Default max failure attempts for stage level retry
+pub const STAGE_MAX_FAILURES: usize = 4;
+
 #[derive(Clone)]
 pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
     state: Arc<dyn StateBackendClient>,
-    #[allow(dead_code)]
-    clients: ExecutorClients,
     session_builder: SessionBuilder,
     codec: BallistaCodec<T, U>,
     scheduler_id: String,
@@ -61,6 +64,15 @@ pub struct TaskManager<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>
     active_job_cache: ExecutionGraphCache,
 }
 
+#[derive(Clone)]
+pub struct UpdatedStages {
+    pub resolved_stages: HashSet<usize>,
+    pub successful_stages: HashSet<usize>,
+    pub failed_stages: HashMap<usize, String>,
+    pub rollback_running_stages: HashMap<usize, HashSet<String>>,
+    pub resubmit_successful_stages: HashSet<usize>,
+}
+
 impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U> {
     pub fn new(
         state: Arc<dyn StateBackendClient>,
@@ -70,7 +82,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
     ) -> Self {
         Self {
             state,
-            clients: Default::default(),
             session_builder,
             codec,
             scheduler_id,
@@ -124,7 +135,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             let graph = self.get_execution_graph(job_id).await?;
             let mut completed_stages = 0;
             for stage in graph.stages().values() {
-                if let ExecutionStage::Completed(_) = stage {
+                if let ExecutionStage::Successful(_) = stage {
                     completed_stages += 1;
                 }
             }
@@ -185,13 +196,9 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         let mut job_updates: HashMap<String, Vec<TaskStatus>> = HashMap::new();
         for status in task_status {
             debug!("Task Update\n{:?}", status);
-            if let Some(job_id) = status.task_id.as_ref().map(|id| &id.job_id) {
-                let job_task_statuses =
-                    job_updates.entry(job_id.clone()).or_insert_with(Vec::new);
-                job_task_statuses.push(status);
-            } else {
-                warn!("Received task with no job ID");
-            }
+            let job_id = status.job_id.clone();
+            let job_task_statuses = job_updates.entry(job_id).or_insert_with(Vec::new);
+            job_task_statuses.push(status);
         }
 
         let mut events: Vec<QueryStageSchedulerEvent> = vec![];
@@ -200,16 +207,21 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             debug!("Updating {} tasks in job {}", num_tasks, job_id);
 
             let graph = self.get_active_execution_graph(&job_id).await;
-            let job_event = if let Some(graph) = graph {
+            let job_events = if let Some(graph) = graph {
                 let mut graph = graph.write().await;
-                graph.update_task_status(executor, statuses)?
+                graph.update_task_status(
+                    executor,
+                    statuses,
+                    TASK_MAX_FAILURES,
+                    STAGE_MAX_FAILURES,
+                )?
             } else {
                 // TODO Deal with curator changed case
                 error!("Fail to find job {} in the active cache and it may not be curated by this scheduler", job_id);
-                None
+                vec![]
             };
 
-            if let Some(event) = job_event {
+            for event in job_events {
                 events.push(event);
             }
         }
@@ -232,7 +244,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
     pub async fn fill_reservations(
         &self,
         reservations: &[ExecutorReservation],
-    ) -> Result<(Vec<(String, Task)>, Vec<ExecutorReservation>, usize)> {
+    ) -> Result<(
+        Vec<(String, TaskDescription)>,
+        Vec<ExecutorReservation>,
+        usize,
+    )> {
         // Reinitialize the free reservations.
         let free_reservations: Vec<ExecutorReservation> = reservations
             .iter()
@@ -241,7 +257,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             })
             .collect();
 
-        let mut assignments: Vec<(String, Task)> = vec![];
+        let mut assignments: Vec<(String, TaskDescription)> = vec![];
         let mut pending_tasks = 0usize;
         let mut assign_tasks = 0usize;
         let job_cache = self.active_job_cache.read().await;
@@ -268,16 +284,16 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         Ok((assignments, unassigned, pending_tasks))
     }
 
-    /// Mark a job as completed. This will create a key under the CompletedJobs keyspace
+    /// Mark a job to success. This will create a key under the CompletedJobs keyspace
     /// and remove the job from ActiveJobs
-    pub async fn complete_job(&self, job_id: &str) -> Result<()> {
-        debug!("Moving job {} from Active to Completed", job_id);
+    pub async fn succeed_job(&self, job_id: &str) -> Result<()> {
+        debug!("Moving job {} from Active to Success", job_id);
         let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
         with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?;
 
         if let Some(graph) = self.get_active_execution_graph(job_id).await {
             let graph = graph.read().await.clone();
-            if graph.complete() {
+            if graph.is_successful() {
                 let value = self.encode_execution_graph(graph)?;
                 self.state
                     .put(Keyspace::CompletedJobs, job_id.to_owned(), value)
@@ -292,93 +308,70 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         Ok(())
     }
 
-    pub(crate) async fn cancel_job(
+    /// Cancel the job with failure reason "Cancelled"; returns the running tasks to cancel.
+    pub(crate) async fn cancel_job(&self, job_id: &str) -> Result<Vec<RunningTaskInfo>> {
+        self.abort_job(job_id, "Cancelled".to_owned()).await
+    }
+
+    /// Fail the job with `failure_reason`; returns its running tasks to cancel (empty if uncached).
+    pub(crate) async fn abort_job(
         &self,
         job_id: &str,
-        executor_manager: &ExecutorManager,
-    ) -> Result<()> {
+        failure_reason: String,
+    ) -> Result<Vec<RunningTaskInfo>> {
         let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
-
-        let running_tasks = self
-            .get_execution_graph(job_id)
-            .await
-            .map(|graph| graph.running_tasks())
-            .unwrap_or_else(|_| vec![]);
-
-        info!(
-            "Cancelling {} running tasks for job {}",
-            running_tasks.len(),
-            job_id
-        );
-
-        self.fail_job_inner(lock, job_id, "Cancelled".to_owned())
-            .await?;
-
-        let mut tasks: HashMap<&str, Vec<protobuf::PartitionId>> = Default::default();
-
-        for (partition, executor_id) in &running_tasks {
-            if let Some(parts) = tasks.get_mut(executor_id.as_str()) {
-                parts.push(protobuf::PartitionId {
-                    job_id: job_id.to_owned(),
-                    stage_id: partition.stage_id as u32,
-                    partition_id: partition.partition_id as u32,
-                })
-            } else {
-                tasks.insert(
-                    executor_id.as_str(),
-                    vec![protobuf::PartitionId {
-                        job_id: job_id.to_owned(),
-                        stage_id: partition.stage_id as u32,
-                        partition_id: partition.partition_id as u32,
-                    }],
-                );
-            }
-        }
-
-        for (executor_id, partitions) in tasks {
-            if let Ok(mut client) = executor_manager.get_client(executor_id).await {
-                client
-                    .cancel_tasks(CancelTasksParams {
-                        partition_id: partitions,
-                    })
-                    .await?;
-            } else {
-                error!("Failed to get client for executor ID {}", executor_id)
-            }
+        if let Some(graph) = self.get_active_execution_graph(job_id).await {
+            let running_tasks = graph.read().await.running_tasks();
+            info!(
+                "Cancelling {} running tasks for job {}",
+                running_tasks.len(),
+                job_id
+            );
+            self.fail_job_state(lock, job_id, failure_reason).await?;
+            Ok(running_tasks)
+        } else {
+            // TODO: listen to job state update events so tasks can be cancelled here too.
+            warn!("Fail to find job {} in the cache, unable to cancel tasks for job, fail the job state only.", job_id);
+            self.fail_job_state(lock, job_id, failure_reason).await?;
+            Ok(vec![])
         }
-
-        Ok(())
     }
 
-    /// Mark a job as failed. This will create a key under the FailedJobs keyspace
+    /// Mark an unscheduled job as failed. This will create a key under the FailedJobs keyspace
     /// and remove the job from ActiveJobs or QueuedJobs
     /// TODO this should be atomic
-    pub async fn fail_job(&self, job_id: &str, error_message: String) -> Result<()> {
+    pub async fn fail_unscheduled_job(
+        &self,
+        job_id: &str,
+        failure_reason: String,
+    ) -> Result<()> {
         debug!("Moving job {} from Active or Queue to Failed", job_id);
         let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
-        self.fail_job_inner(lock, job_id, error_message).await
+        self.fail_job_state(lock, job_id, failure_reason).await
     }
 
-    async fn fail_job_inner(
+    async fn fail_job_state(
         &self,
         lock: Box<dyn Lock>,
         job_id: &str,
-        error_message: String,
+        failure_reason: String,
     ) -> Result<()> {
         with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?;
 
         let value = if let Some(graph) = self.get_active_execution_graph(job_id).await {
             let mut graph = graph.write().await;
-            graph.fail_job(error_message);
+            for stage_id in graph.running_stages() {
+                graph.fail_stage(stage_id, failure_reason.clone());
+            }
+            graph.fail_job(failure_reason);
             let graph = graph.clone();
-
             self.encode_execution_graph(graph)?
         } else {
             warn!("Fail to find job {} in the cache", job_id);
 
             let status = JobStatus {
                 status: Some(job_status::Status::Failed(FailedJob {
-                    error: error_message.clone(),
+                    error: failure_reason.clone(),
                 })),
             };
             encode_protobuf(&status)?
@@ -391,27 +384,6 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         Ok(())
     }
 
-    /// Mark a job as failed. This will create a key under the FailedJobs keyspace
-    /// and remove the job from ActiveJobs or QueuedJobs
-    /// TODO this should be atomic
-    pub async fn fail_running_job(&self, job_id: &str) -> Result<()> {
-        if let Some(graph) = self.get_active_execution_graph(job_id).await {
-            let graph = graph.read().await.clone();
-            let value = self.encode_execution_graph(graph)?;
-
-            debug!("Moving job {} from Active to Failed", job_id);
-            let lock = self.state.lock(Keyspace::ActiveJobs, "").await?;
-            with_lock(lock, self.state.delete(Keyspace::ActiveJobs, job_id)).await?;
-            self.state
-                .put(Keyspace::FailedJobs, job_id.to_owned(), value)
-                .await?;
-        } else {
-            warn!("Fail to find job {} in the cache", job_id);
-        }
-
-        Ok(())
-    }
-
     pub async fn update_job(&self, job_id: &str) -> Result<()> {
         debug!("Update job {} in Active", job_id);
         if let Some(graph) = self.get_active_execution_graph(job_id).await {
@@ -429,16 +401,20 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
         Ok(())
     }
 
-    pub async fn executor_lost(&self, executor_id: &str) -> Result<()> {
+    /// return a Vec of running tasks need to cancel
+    pub async fn executor_lost(&self, executor_id: &str) -> Result<Vec<RunningTaskInfo>> {
+        // Collect all the running task need to cancel when there are running stages rolled back.
+        let mut running_tasks_to_cancel: Vec<RunningTaskInfo> = vec![];
         // Collect graphs we update so we can update them in storage
         let mut updated_graphs: HashMap<String, ExecutionGraph> = HashMap::new();
         {
             let job_cache = self.active_job_cache.read().await;
             for (job_id, graph) in job_cache.iter() {
                 let mut graph = graph.write().await;
-                let reset = graph.reset_stages(executor_id)?;
-                if !reset.is_empty() {
+                let reset = graph.reset_stages_on_lost_executor(executor_id)?;
+                if !reset.0.is_empty() {
                     updated_graphs.insert(job_id.to_owned(), graph.clone());
+                    running_tasks_to_cancel.extend(reset.1);
                 }
             }
         }
@@ -454,7 +430,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
                 })
                 .collect::<Result<Vec<_>>>()?;
             self.state.put_txn(txn_ops).await?;
-            Ok(())
+            Ok(running_tasks_to_cancel)
         })
         .await
     }
@@ -464,7 +440,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
     pub(crate) async fn launch_task(
         &self,
         executor: &ExecutorMetadata,
-        task: Task,
+        task: TaskDescription,
         executor_manager: &ExecutorManager,
     ) -> Result<()> {
         info!("Launching task {:?} on executor {:?}", task, executor.id);
@@ -490,7 +466,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
     pub(crate) async fn launch_task(
         &self,
         _executor: &ExecutorMetadata,
-        _task: Task,
+        _task: TaskDescription,
         _executor_manager: &ExecutorManager,
     ) -> Result<()> {
         Ok(())
@@ -509,7 +485,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
     }
 
     #[allow(dead_code)]
-    pub fn prepare_task_definition(&self, task: Task) -> Result<TaskDefinition> {
+    pub fn prepare_task_definition(
+        &self,
+        task: TaskDescription,
+    ) -> Result<TaskDefinition> {
         debug!("Preparing task definition for {:?}", task);
         let mut plan_buf: Vec<u8> = vec![];
         let plan_proto =
@@ -520,14 +499,19 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> TaskManager<T, U>
             hash_partitioning_to_proto(task.output_partitioning.as_ref())?;
 
         let task_definition = TaskDefinition {
-            task_id: Some(protobuf::PartitionId {
-                job_id: task.partition.job_id.clone(),
-                stage_id: task.partition.stage_id as u32,
-                partition_id: task.partition.partition_id as u32,
-            }),
+            task_id: task.task_id as u32,
+            task_attempt_num: task.task_attempt as u32,
+            job_id: task.partition.job_id.clone(),
+            stage_id: task.partition.stage_id as u32,
+            stage_attempt_num: task.stage_attempt_num as u32,
+            partition_id: task.partition.partition_id as u32,
             plan: plan_buf,
             output_partitioning,
             session_id: task.session_id,
+            launch_time: SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_millis() as u64,
             props: vec![],
         };
         Ok(task_definition)