You are viewing a plain text version of this content. The canonical link for it is here.
Posted to dev@nemo.apache.org by GitBox <gi...@apache.org> on 2018/06/29 06:05:07 UTC

[GitHub] seojangho closed pull request #59: [NEMO-50] Carefully retry tasks in the scheduler

seojangho closed pull request #59: [NEMO-50] Carefully retry tasks in the scheduler
URL: https://github.com/apache/incubator-nemo/pull/59
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/README.md b/README.md
index d7b1db83..9295593e 100644
--- a/README.md
+++ b/README.md
@@ -139,3 +139,7 @@ Nemo Compiler and Engine can store JSON representation of intermediate DAGs.
   	-dag_dir "./dag/als" \
   	-user_args "`pwd`/examples/resources/sample_input_als 10 3"
 ```
+
+## Speeding up builds 
+* To exclude Spark related packages: mvn clean install -T 2C -DskipTests -pl \\!compiler/frontend/spark,\\!examples/spark
+* To exclude Beam related packages: mvn clean install -T 2C -DskipTests -pl \\!compiler/frontend/beam,\\!examples/beam
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
index a45b9b60..504501bb 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/plan/StageEdge.java
@@ -95,16 +95,16 @@ public StageEdge(final String runtimeEdgeId,
   }
 
   /**
-   * @return the source vertex of the edge.
+   * @return the source IR vertex of the edge.
    */
-  public IRVertex getSrcVertex() {
+  public IRVertex getSrcIRVertex() {
     return srcVertex;
   }
 
   /**
-   * @return the destination vertex of the edge.
+   * @return the destination IR vertex of the edge.
    */
-  public IRVertex getDstVertex() {
+  public IRVertex getDstIRVertex() {
     return dstVertex;
   }
 
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java
index 6bf33805..e7edbade 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/StageState.java
@@ -19,6 +19,10 @@
 
 /**
  * Represents the states and their transitions of a stage.
+ *
+ * Maintained as two simple states (INCOMPLETE, COMPLETE) to avoid ambiguity when the tasks are in different states.
+ * For example, it is not clear whether a stage should be EXECUTING or SHOULD_RETRY if one of the tasks in the stage
+ * is EXECUTING and another is SHOULD_RETRY.
  */
 public final class StageState {
   private final StateMachine stateMachine;
@@ -31,31 +35,17 @@ private StateMachine buildTaskStateMachine() {
     final StateMachine.Builder stateMachineBuilder = StateMachine.newBuilder();
 
     // Add states
-    stateMachineBuilder.addState(State.READY, "The stage has been created.");
-    stateMachineBuilder.addState(State.EXECUTING, "The stage is executing.");
+    stateMachineBuilder.addState(State.INCOMPLETE, "Some tasks in this stage are not complete.");
     stateMachineBuilder.addState(State.COMPLETE, "All of this stage's tasks have completed.");
-    stateMachineBuilder.addState(State.FAILED_RECOVERABLE, "Stage failed, but is recoverable.");
 
     // Add transitions
-    stateMachineBuilder.addTransition(State.READY, State.EXECUTING,
-        "The stage can now schedule its tasks");
-    stateMachineBuilder.addTransition(State.READY, State.FAILED_RECOVERABLE,
-        "Recoverable failure");
-
-    stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE,
-        "All tasks complete");
-    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE,
-        "Recoverable failure in a task");
-
-    stateMachineBuilder.addTransition(State.COMPLETE, State.FAILED_RECOVERABLE,
-        "Container on which the stage's output is stored failed");
-
-    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY,
-        "Recoverable stage failure");
-    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.EXECUTING,
-        "Recoverable stage failure");
+    stateMachineBuilder.addTransition(
+        State.INCOMPLETE, State.INCOMPLETE, "A task in the stage needs to be retried");
+    stateMachineBuilder.addTransition(State.INCOMPLETE, State.COMPLETE, "All tasks complete");
+    stateMachineBuilder.addTransition(State.COMPLETE, State.INCOMPLETE,
+        "Completed before, but a task in this stage should be retried");
 
-    stateMachineBuilder.setInitialState(State.READY);
+    stateMachineBuilder.setInitialState(State.INCOMPLETE);
 
     return stateMachineBuilder.build();
   }
@@ -68,10 +58,8 @@ public StateMachine getStateMachine() {
    * StageState.
    */
   public enum State {
-    READY,
-    EXECUTING,
-    COMPLETE,
-    FAILED_RECOVERABLE,
+    INCOMPLETE,
+    COMPLETE
   }
 
   @Override
diff --git a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
index b47696af..74b808c7 100644
--- a/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
+++ b/runtime/common/src/main/java/edu/snu/nemo/runtime/common/state/TaskState.java
@@ -31,37 +31,32 @@ private StateMachine buildTaskStateMachine() {
     final StateMachine.Builder stateMachineBuilder = StateMachine.newBuilder();
 
     // Add states
-    stateMachineBuilder.addState(State.READY, "The task has been created.");
+    stateMachineBuilder.addState(State.READY, "The task is ready to be executed.");
     stateMachineBuilder.addState(State.EXECUTING, "The task is executing.");
+    stateMachineBuilder.addState(State.ON_HOLD, "The task is paused (e.g., for dynamic optimization).");
     stateMachineBuilder.addState(State.COMPLETE, "The task has completed.");
-    stateMachineBuilder.addState(State.FAILED_RECOVERABLE, "Task failed, but is recoverable.");
-    stateMachineBuilder.addState(State.FAILED_UNRECOVERABLE, "Task failed, and is unrecoverable. The job will fail.");
-    stateMachineBuilder.addState(State.ON_HOLD, "The task is paused for dynamic optimization.");
+    stateMachineBuilder.addState(State.SHOULD_RETRY, "The task should be retried.");
+    stateMachineBuilder.addState(State.FAILED, "Task failed, and is unrecoverable. The job will fail.");
 
-    // From NOT_AVAILABLE
+    // From READY
     stateMachineBuilder.addTransition(State.READY, State.EXECUTING, "Scheduling to executor");
-    stateMachineBuilder.addTransition(State.READY, State.FAILED_RECOVERABLE,
-        "Stage Failure by a recoverable failure in another task");
 
     // From EXECUTING
     stateMachineBuilder.addTransition(State.EXECUTING, State.COMPLETE, "Task completed normally");
-    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_UNRECOVERABLE, "Unrecoverable failure");
-    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED_RECOVERABLE, "Recoverable failure");
     stateMachineBuilder.addTransition(State.EXECUTING, State.ON_HOLD, "Task paused for dynamic optimization");
+    stateMachineBuilder.addTransition(State.EXECUTING, State.SHOULD_RETRY, "Did not complete, should be retried");
+    stateMachineBuilder.addTransition(State.EXECUTING, State.FAILED, "Unrecoverable failure");
 
     // From ON HOLD
     stateMachineBuilder.addTransition(State.ON_HOLD, State.COMPLETE, "Task completed after being on hold");
-    stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED_UNRECOVERABLE, "Unrecoverable failure");
-    stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED_RECOVERABLE, "Recoverable failure");
+    stateMachineBuilder.addTransition(State.ON_HOLD, State.SHOULD_RETRY, "Did not complete, should be retried");
+    stateMachineBuilder.addTransition(State.ON_HOLD, State.FAILED, "Unrecoverable failure");
 
     // From COMPLETE
-    stateMachineBuilder.addTransition(State.COMPLETE, State.EXECUTING, "Completed before, but re-execute");
-    stateMachineBuilder.addTransition(State.COMPLETE, State.FAILED_RECOVERABLE,
-        "Recoverable failure in a task/Container failure");
+    stateMachineBuilder.addTransition(State.COMPLETE, State.SHOULD_RETRY, "Completed before, but should be retried");
 
-
-    // From FAILED_RECOVERABLE
-    stateMachineBuilder.addTransition(State.FAILED_RECOVERABLE, State.READY, "Recovered from failure and is ready");
+    // From SHOULD_RETRY
+    stateMachineBuilder.addTransition(State.SHOULD_RETRY, State.READY, "Ready to be retried");
 
     stateMachineBuilder.setInitialState(State.READY);
     return stateMachineBuilder.build();
@@ -77,19 +72,18 @@ public StateMachine getStateMachine() {
   public enum State {
     READY,
     EXECUTING,
-    COMPLETE,
-    FAILED_RECOVERABLE,
-    FAILED_UNRECOVERABLE,
     ON_HOLD, // for dynamic optimization
+    COMPLETE,
+    SHOULD_RETRY,
+    FAILED,
   }
 
   /**
    * Causes of a recoverable failure.
    */
-  public enum RecoverableFailureCause {
+  public enum RecoverableTaskFailureCause {
     INPUT_READ_FAILURE, // Occurs when a task is unable to read its input block
     OUTPUT_WRITE_FAILURE, // Occurs when a task successfully generates its output, but is unable to write it
-    CONTAINER_FAILURE // When a REEF evaluator fails
   }
 
   @Override
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java
index bf013504..96456218 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/TaskStateManager.java
@@ -63,7 +63,7 @@ public TaskStateManager(final Task task,
    */
   public synchronized void onTaskStateChanged(final TaskState.State newState,
                                               final Optional<String> vertexPutOnHold,
-                                              final Optional<TaskState.RecoverableFailureCause> cause) {
+                                              final Optional<TaskState.RecoverableTaskFailureCause> cause) {
     final Map<String, Object> metric = new HashMap<>();
 
     switch (newState) {
@@ -80,13 +80,13 @@ public synchronized void onTaskStateChanged(final TaskState.State newState,
         metricCollector.endMeasurement(taskId, metric);
         notifyTaskStateToMaster(newState, Optional.empty(), cause);
         break;
-      case FAILED_RECOVERABLE:
+      case SHOULD_RETRY:
         LOG.debug("Task ID {} failed (recoverable).", this.taskId);
         metric.put("ToState", newState);
         metricCollector.endMeasurement(taskId, metric);
         notifyTaskStateToMaster(newState, Optional.empty(), cause);
         break;
-      case FAILED_UNRECOVERABLE:
+      case FAILED:
         LOG.debug("Task ID {} failed (unrecoverable).", this.taskId);
         metric.put("ToState", newState);
         metricCollector.endMeasurement(taskId, metric);
@@ -109,7 +109,7 @@ public synchronized void onTaskStateChanged(final TaskState.State newState,
    */
   private void notifyTaskStateToMaster(final TaskState.State newState,
                                        final Optional<String> vertexPutOnHold,
-                                       final Optional<TaskState.RecoverableFailureCause> cause) {
+                                       final Optional<TaskState.RecoverableTaskFailureCause> cause) {
     final ControlMessage.TaskStateChangedMsg.Builder msgBuilder =
         ControlMessage.TaskStateChangedMsg.newBuilder()
             .setExecutorId(executorId)
@@ -141,9 +141,9 @@ private void notifyTaskStateToMaster(final TaskState.State newState,
         return ControlMessage.TaskStateFromExecutor.EXECUTING;
       case COMPLETE:
         return ControlMessage.TaskStateFromExecutor.COMPLETE;
-      case FAILED_RECOVERABLE:
+      case SHOULD_RETRY:
         return ControlMessage.TaskStateFromExecutor.FAILED_RECOVERABLE;
-      case FAILED_UNRECOVERABLE:
+      case FAILED:
         return ControlMessage.TaskStateFromExecutor.FAILED_UNRECOVERABLE;
       case ON_HOLD:
         return ControlMessage.TaskStateFromExecutor.ON_HOLD;
@@ -153,7 +153,7 @@ private void notifyTaskStateToMaster(final TaskState.State newState,
   }
 
   private ControlMessage.RecoverableFailureCause convertFailureCause(
-      final TaskState.RecoverableFailureCause cause) {
+      final TaskState.RecoverableTaskFailureCause cause) {
     switch (cause) {
       case INPUT_READ_FAILURE:
         return ControlMessage.RecoverableFailureCause.InputReadFailure;
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
index 52c55fdd..144ded08 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/data/BlockManagerWorker.java
@@ -135,11 +135,10 @@ public Block createBlock(final String blockId,
    * @param keyRange   the key range descriptor.
    * @return the result data in the block.
    */
-  private CompletableFuture<DataUtil.IteratorWithNumBytes> retrieveDataFromBlock(
+  private CompletableFuture<DataUtil.IteratorWithNumBytes> getDataFromLocalBlock(
       final String blockId,
       final InterTaskDataStoreProperty.Value blockStore,
       final KeyRange keyRange) {
-    LOG.info("RetrieveDataFromBlock: {}", blockId);
     final BlockStore store = getBlockStore(blockStore);
 
     // First, try to fetch the block from local BlockStore.
@@ -229,7 +228,7 @@ public Block createBlock(final String blockId,
       final String targetExecutorId = blockLocationInfoMsg.getOwnerExecutorId();
       if (targetExecutorId.equals(executorId) || targetExecutorId.equals(REMOTE_FILE_STORE)) {
         // Block resides in the evaluator
-        return retrieveDataFromBlock(blockId, blockStore, keyRange);
+        return getDataFromLocalBlock(blockId, blockStore, keyRange);
       } else {
         final ByteTransferContextDescriptor descriptor = ByteTransferContextDescriptor.newBuilder()
             .setBlockId(blockId)
diff --git a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
index fe20ade9..c320785a 100644
--- a/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
+++ b/runtime/executor/src/main/java/edu/snu/nemo/runtime/executor/task/TaskExecutor.java
@@ -217,7 +217,7 @@ public void execute() {
       doExecute();
     } catch (Throwable throwable) {
       // ANY uncaught throwable is reported to the master
-      taskStateManager.onTaskStateChanged(TaskState.State.FAILED_UNRECOVERABLE, Optional.empty(), Optional.empty());
+      taskStateManager.onTaskStateChanged(TaskState.State.FAILED, Optional.empty(), Optional.empty());
       LOG.error(ExceptionUtils.getStackTrace(throwable));
     }
   }
@@ -319,8 +319,8 @@ private boolean handleDataFetchers(final List<DataFetcher> fetchers) {
         try {
           element = dataFetcher.fetchDataElement();
         } catch (IOException e) {
-          taskStateManager.onTaskStateChanged(TaskState.State.FAILED_RECOVERABLE,
-              Optional.empty(), Optional.of(TaskState.RecoverableFailureCause.INPUT_READ_FAILURE));
+          taskStateManager.onTaskStateChanged(TaskState.State.SHOULD_RETRY,
+              Optional.empty(), Optional.of(TaskState.RecoverableTaskFailureCause.INPUT_READ_FAILURE));
           LOG.error("{} Execution Failed (Recoverable: input read failure)! Exception: {}", taskId, e.toString());
           return false;
         }
@@ -366,9 +366,9 @@ private boolean handleDataFetchers(final List<DataFetcher> fetchers) {
                                                  final DataTransferFactory dataTransferFactory) {
     return inEdgesFromParentTasks
         .stream()
-        .filter(inEdge -> inEdge.getDstVertex().getId().equals(irVertex.getId()))
+        .filter(inEdge -> inEdge.getDstIRVertex().getId().equals(irVertex.getId()))
         .map(inEdgeForThisVertex -> dataTransferFactory
-            .createReader(taskIndex, inEdgeForThisVertex.getSrcVertex(), inEdgeForThisVertex))
+            .createReader(taskIndex, inEdgeForThisVertex.getSrcIRVertex(), inEdgeForThisVertex))
         .collect(Collectors.toList());
   }
 
@@ -378,9 +378,9 @@ private boolean handleDataFetchers(final List<DataFetcher> fetchers) {
                                                     final DataTransferFactory dataTransferFactory) {
     return outEdgesToChildrenTasks
         .stream()
-        .filter(outEdge -> outEdge.getSrcVertex().getId().equals(irVertex.getId()))
+        .filter(outEdge -> outEdge.getSrcIRVertex().getId().equals(irVertex.getId()))
         .map(outEdgeForThisVertex -> dataTransferFactory
-            .createWriter(irVertex, taskIndex, outEdgeForThisVertex.getDstVertex(), outEdgeForThisVertex))
+            .createWriter(irVertex, taskIndex, outEdgeForThisVertex.getDstIRVertex(), outEdgeForThisVertex))
         .collect(Collectors.toList());
   }
 
diff --git a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
index 0c587445..96b755d2 100644
--- a/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
+++ b/runtime/executor/src/test/java/edu/snu/nemo/runtime/executor/task/TaskExecutorTest.java
@@ -273,15 +273,15 @@ public void testTwoOperatorsWithSideInput() throws Exception {
 
   private StageEdge mockStageEdgeFrom(final IRVertex irVertex) {
     final StageEdge edge = mock(StageEdge.class);
-    when(edge.getSrcVertex()).thenReturn(irVertex);
-    when(edge.getDstVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
+    when(edge.getSrcIRVertex()).thenReturn(irVertex);
+    when(edge.getDstIRVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
     return edge;
   }
 
   private StageEdge mockStageEdgeTo(final IRVertex irVertex) {
     final StageEdge edge = mock(StageEdge.class);
-    when(edge.getSrcVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
-    when(edge.getDstVertex()).thenReturn(irVertex);
+    when(edge.getSrcIRVertex()).thenReturn(new OperatorVertex(new RelayTransform()));
+    when(edge.getDstIRVertex()).thenReturn(irVertex);
     return edge;
   }
 
diff --git a/runtime/master/pom.xml b/runtime/master/pom.xml
index 75c7c251..6cb5abc7 100644
--- a/runtime/master/pom.xml
+++ b/runtime/master/pom.xml
@@ -61,5 +61,15 @@ limitations under the License.
             <artifactId>jackson-databind</artifactId>
             <version>${jackson.version}</version>
         </dependency>
+        <dependency>
+            <!--
+            This is needed to view the logs when running unit tests.
+            See https://dzone.com/articles/how-configure-slf4j-different for details.
+            -->
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-simple</artifactId>
+            <version>1.6.2</version>
+            <scope>test</scope>
+        </dependency>
     </dependencies>
 </project>
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
index 1da2e722..da7c1ef6 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/BlockManagerMaster.java
@@ -15,14 +15,20 @@
  */
 package edu.snu.nemo.runtime.master;
 
+import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.exception.IllegalMessageException;
 import edu.snu.nemo.common.exception.UnknownExecutionStateException;
+import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.runtime.common.comm.ControlMessage;
 import edu.snu.nemo.runtime.common.exception.AbsentBlockException;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.message.MessageContext;
 import edu.snu.nemo.runtime.common.message.MessageEnvironment;
 import edu.snu.nemo.runtime.common.message.MessageListener;
+import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
+import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
+import edu.snu.nemo.runtime.common.plan.Stage;
+import edu.snu.nemo.runtime.common.plan.StageEdge;
 import edu.snu.nemo.runtime.common.state.BlockState;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -37,6 +43,7 @@
 import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
+import java.util.stream.IntStream;
 
 import org.apache.reef.annotations.audience.DriverSide;
 import org.slf4j.Logger;
@@ -74,6 +81,36 @@ private BlockManagerMaster(final MessageEnvironment masterMessageEnvironment) {
     this.lock = new ReentrantReadWriteLock();
   }
 
+  public void initialize(final PhysicalPlan physicalPlan) {
+    final DAG<Stage, StageEdge> stageDAG = physicalPlan.getStageDAG();
+    stageDAG.topologicalDo(stage -> {
+      final List<String> taskIdsForStage = stage.getTaskIds();
+      final List<StageEdge> stageOutgoingEdges = stageDAG.getOutgoingEdgesOf(stage);
+
+      // Initialize states for blocks of inter-stage edges
+      stageOutgoingEdges.forEach(stageEdge -> {
+        final int srcParallelism = taskIdsForStage.size();
+        IntStream.range(0, srcParallelism).forEach(srcTaskIdx -> {
+          final String blockId = RuntimeIdGenerator.generateBlockId(stageEdge.getId(), srcTaskIdx);
+          initializeState(blockId, taskIdsForStage.get(srcTaskIdx));
+        });
+      });
+
+      // Initialize states for blocks of stage internal edges
+      taskIdsForStage.forEach(taskId -> {
+        final DAG<IRVertex, RuntimeEdge<IRVertex>> taskInternalDag = stage.getIRDAG();
+        taskInternalDag.getVertices().forEach(task -> {
+          final List<RuntimeEdge<IRVertex>> internalOutgoingEdges = taskInternalDag.getOutgoingEdgesOf(task);
+          internalOutgoingEdges.forEach(taskRuntimeEdge -> {
+            final int srcTaskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
+            final String blockId = RuntimeIdGenerator.generateBlockId(taskRuntimeEdge.getId(), srcTaskIdx);
+            initializeState(blockId, taskId);
+          });
+        });
+      });
+    });
+  }
+
   /**
    * Initializes the states of a block which will be produced by producer task(s).
    *
@@ -132,8 +169,7 @@ public BlockLocationRequestHandler getBlockLocationHandler(final String blockId)
     final Lock readLock = lock.readLock();
     readLock.lock();
     try {
-      final BlockState.State state =
-          (BlockState.State) getBlockState(blockId).getStateMachine().getCurrentState();
+      final BlockState.State state = getBlockState(blockId);
       switch (state) {
         case IN_PROGRESS:
         case AVAILABLE:
@@ -174,6 +210,16 @@ public BlockLocationRequestHandler getBlockLocationHandler(final String blockId)
     }
   }
 
+  public Set<String> getIdsOfBlocksProducedBy(final String taskId) {
+    final Lock readLock = lock.readLock();
+    readLock.lock();
+    try {
+      return producerTaskIdToBlockIds.get(taskId);
+    } finally {
+      readLock.unlock();
+    }
+  }
+
   /**
    * To be called when a potential producer task is scheduled.
    * @param scheduledTaskId the ID of the scheduled task.
@@ -256,11 +302,11 @@ public void onProducerTaskFailed(final String failedTaskId) {
    * @return the {@link BlockState} of a block.
    */
   @VisibleForTesting
-  BlockState getBlockState(final String blockId) {
+  public BlockState.State getBlockState(final String blockId) {
     final Lock readLock = lock.readLock();
     readLock.lock();
     try {
-      return blockIdToMetadata.get(blockId).getBlockState();
+      return (BlockState.State) blockIdToMetadata.get(blockId).getBlockState().getStateMachine().getCurrentState();
     } finally {
       readLock.unlock();
     }
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java
index 220e5854..2a0ef882 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/JobStateManager.java
@@ -15,18 +15,15 @@
  */
 package edu.snu.nemo.runtime.master;
 
+import com.google.common.annotations.VisibleForTesting;
 import edu.snu.nemo.common.exception.IllegalStateTransitionException;
 import edu.snu.nemo.common.exception.SchedulingException;
 import edu.snu.nemo.common.exception.UnknownExecutionStateException;
 import edu.snu.nemo.common.StateMachine;
-import edu.snu.nemo.common.dag.DAG;
-import edu.snu.nemo.common.ir.vertex.IRVertex;
 import edu.snu.nemo.runtime.common.metric.MetricDataBuilder;
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
 import edu.snu.nemo.runtime.common.plan.Stage;
-import edu.snu.nemo.runtime.common.plan.StageEdge;
-import edu.snu.nemo.runtime.common.plan.RuntimeEdge;
 import edu.snu.nemo.runtime.common.state.JobState;
 import edu.snu.nemo.runtime.common.state.StageState;
 
@@ -45,7 +42,6 @@
 import org.slf4j.LoggerFactory;
 
 import javax.annotation.concurrent.ThreadSafe;
-import java.util.stream.IntStream;
 
 import static edu.snu.nemo.common.dag.DAG.EMPTY_DAG_DIRECTORY;
 
@@ -94,7 +90,6 @@
   private final Map<String, MetricDataBuilder> metricDataBuilderMap;
 
   public JobStateManager(final PhysicalPlan physicalPlan,
-                         final BlockManagerMaster blockManagerMaster,
                          final MetricMessageHandler metricMessageHandler,
                          final int maxScheduleAttempt) {
     this.jobId = physicalPlan.getId();
@@ -109,7 +104,6 @@ public JobStateManager(final PhysicalPlan physicalPlan,
     this.jobFinishedCondition = finishLock.newCondition();
     this.metricDataBuilderMap = new HashMap<>();
     initializeComputationStates();
-    initializePartitionStates(blockManagerMaster);
   }
 
   /**
@@ -128,36 +122,6 @@ private void initializeComputationStates() {
     });
   }
 
-  private void initializePartitionStates(final BlockManagerMaster blockManagerMaster) {
-    final DAG<Stage, StageEdge> stageDAG = physicalPlan.getStageDAG();
-    stageDAG.topologicalDo(stage -> {
-      final List<String> taskIdsForStage = stage.getTaskIds();
-      final List<StageEdge> stageOutgoingEdges = stageDAG.getOutgoingEdgesOf(stage);
-
-      // Initialize states for blocks of inter-stage edges
-      stageOutgoingEdges.forEach(stageEdge -> {
-        final int srcParallelism = taskIdsForStage.size();
-        IntStream.range(0, srcParallelism).forEach(srcTaskIdx -> {
-          final String blockId = RuntimeIdGenerator.generateBlockId(stageEdge.getId(), srcTaskIdx);
-          blockManagerMaster.initializeState(blockId, taskIdsForStage.get(srcTaskIdx));
-        });
-      });
-
-      // Initialize states for blocks of stage internal edges
-      taskIdsForStage.forEach(taskId -> {
-        final DAG<IRVertex, RuntimeEdge<IRVertex>> taskInternalDag = stage.getIRDAG();
-        taskInternalDag.getVertices().forEach(task -> {
-          final List<RuntimeEdge<IRVertex>> internalOutgoingEdges = taskInternalDag.getOutgoingEdgesOf(task);
-          internalOutgoingEdges.forEach(taskRuntimeEdge -> {
-            final int srcTaskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
-            final String blockId = RuntimeIdGenerator.generateBlockId(taskRuntimeEdge.getId(), srcTaskIdx);
-            blockManagerMaster.initializeState(blockId, taskId);
-          });
-        });
-      });
-    });
-  }
-
   /**
    * Updates the state of a task.
    * Task state changes can occur both in master and executor.
@@ -183,8 +147,8 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState
     switch (newTaskState) {
       case ON_HOLD:
       case COMPLETE:
-      case FAILED_UNRECOVERABLE:
-      case FAILED_RECOVERABLE:
+      case FAILED:
+      case SHOULD_RETRY:
         metric.put("ToState", newTaskState);
         endMeasurement(taskId, metric);
         break;
@@ -213,32 +177,33 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState
         .map(this::getTaskState)
         .filter(state -> state.equals(TaskState.State.COMPLETE) || state.equals(TaskState.State.ON_HOLD))
         .count();
+    if (newTaskState.equals(TaskState.State.COMPLETE)) {
+      // Log not-yet-completed tasks for us to track progress
+      LOG.info("{} completed: {} Task(s) remaining in this stage",
+          taskId, tasksOfThisStage.size() - numOfCompletedOrOnHoldTasksInThisStage);
+    }
     switch (newTaskState) {
-      case READY:
-        onStageStateChanged(stageId, StageState.State.READY);
-        break;
-      case EXECUTING:
-        onStageStateChanged(stageId, StageState.State.EXECUTING);
-        break;
-      case FAILED_RECOVERABLE:
-        onStageStateChanged(stageId, StageState.State.FAILED_RECOVERABLE);
+      // INCOMPLETE stage
+      case SHOULD_RETRY:
+        onStageStateChanged(stageId, StageState.State.INCOMPLETE);
         break;
+
+      // COMPLETE stage
       case COMPLETE:
       case ON_HOLD:
         if (numOfCompletedOrOnHoldTasksInThisStage == tasksOfThisStage.size()) {
           onStageStateChanged(stageId, StageState.State.COMPLETE);
         }
         break;
-      case FAILED_UNRECOVERABLE:
+
+      // Doesn't affect StageState
+      case READY:
+      case EXECUTING:
+      case FAILED:
         break;
       default:
         throw new UnknownExecutionStateException(new Throwable("This task state is unknown"));
     }
-
-    // Log not-yet-completed tasks for us to track progress
-    if (newTaskState.equals(TaskState.State.COMPLETE)) {
-      LOG.info("{}: {} Task(s) to go", stageId, tasksOfThisStage.size() - numOfCompletedOrOnHoldTasksInThisStage);
-    }
   }
 
   /**
@@ -248,11 +213,6 @@ public synchronized void onTaskStateChanged(final String taskId, final TaskState
    * @param newStageState of the stage.
    */
   private void onStageStateChanged(final String stageId, final StageState.State newStageState) {
-    if (newStageState.equals(getStageState(stageId))) {
-      // Ignore duplicate state updates
-      return;
-    }
-
     // Change stage state
     final StateMachine stageStateMachine = idToStageStates.get(stageId).getStateMachine();
     LOG.debug("Stage State Transition: id {} from {} to {}",
@@ -261,7 +221,7 @@ private void onStageStateChanged(final String stageId, final StageState.State ne
 
     // Metric handling
     final Map<String, Object> metric = new HashMap<>();
-    if (newStageState == StageState.State.EXECUTING) {
+    if (newStageState == StageState.State.INCOMPLETE) {
       metric.put("FromState", newStageState);
       beginMeasurement(stageId, metric);
     } else if (newStageState == StageState.State.COMPLETE) {
@@ -269,16 +229,9 @@ private void onStageStateChanged(final String stageId, final StageState.State ne
       endMeasurement(stageId, metric);
     }
 
-    // Change job state if needed
+    // Job becomes COMPLETE
     final boolean allStagesCompleted = idToStageStates.values().stream().allMatch(state ->
         state.getStateMachine().getCurrentState().equals(StageState.State.COMPLETE));
-
-    // (1) Job becomes EXECUTING if not already
-    if (newStageState.equals(StageState.State.EXECUTING)
-        && !getJobState().equals(JobState.State.EXECUTING)) {
-      onJobStateChanged(JobState.State.EXECUTING);
-    }
-    // (2) Job becomes COMPLETE
     if (allStagesCompleted) {
       onJobStateChanged(JobState.State.COMPLETE);
     }
@@ -290,20 +243,15 @@ private void onStageStateChanged(final String stageId, final StageState.State ne
    * @param newState of the job.
    */
   private void onJobStateChanged(final JobState.State newState) {
-    if (newState.equals(getJobState())) {
-      // Ignore duplicate state updates
-      return;
-    }
-
     jobState.getStateMachine().setState(newState);
 
     final Map<String, Object> metric = new HashMap<>();
     if (newState == JobState.State.EXECUTING) {
-      LOG.debug("Executing Job ID {}...", this.jobId);
+      LOG.info("Executing Job ID {}...", this.jobId);
       metric.put("FromState", newState);
       beginMeasurement(jobId, metric);
     } else if (newState == JobState.State.COMPLETE || newState == JobState.State.FAILED) {
-      LOG.debug("Job ID {} {}!", new Object[]{jobId, newState});
+      LOG.info("Job ID {} {}!", new Object[]{jobId, newState});
 
       // Awake all threads waiting the finish of this job.
       finishLock.lock();
@@ -391,6 +339,11 @@ public synchronized int getTaskAttempt(final String taskId) {
     }
   }
 
+  @VisibleForTesting
+  public synchronized Map<String, TaskState> getAllTaskStates() {
+    return idToTaskStates;
+  }
+
   /**
    * Begins recording the start time of this metric measurement, in addition to the metric given.
    * This method ensures thread-safety by synchronizing its callers.
@@ -435,7 +388,7 @@ public void storeJSON(final String directory, final String suffix) {
     file.getParentFile().mkdirs();
     try (final PrintWriter printWriter = new PrintWriter(file)) {
       printWriter.println(toStringWithPhysicalPlan());
-      LOG.debug(String.format("JSON representation of job state for %s(%s) was saved to %s",
+      LOG.info(String.format("JSON representation of job state for %s(%s) was saved to %s",
           jobId, suffix, file.getPath()));
     } catch (final IOException e) {
       LOG.warn(String.format("Cannot store JSON representation of job state for %s(%s) to %s: %s",
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
index 4529a0b2..61ffd203 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/RuntimeMaster.java
@@ -127,8 +127,8 @@ public RuntimeMaster(final Scheduler scheduler,
     final Callable<Pair<JobStateManager, ScheduledExecutorService>> jobExecutionCallable = () -> {
       this.irVertices.addAll(plan.getIdToIRVertex().values());
       try {
-        final JobStateManager jobStateManager =
-            new JobStateManager(plan, blockManagerMaster, metricMessageHandler, maxScheduleAttempt);
+        blockManagerMaster.initialize(plan);
+        final JobStateManager jobStateManager = new JobStateManager(plan, metricMessageHandler, maxScheduleAttempt);
         scheduler.scheduleJob(plan, jobStateManager);
         final ScheduledExecutorService dagLoggingExecutor = scheduleDagLogging(jobStateManager);
         return Pair.of(jobStateManager, dagLoggingExecutor);
@@ -363,9 +363,9 @@ private void accumulateBarrierMetric(final List<ControlMessage.PartitionSizeEntr
       case COMPLETE:
         return COMPLETE;
       case FAILED_RECOVERABLE:
-        return TaskState.State.FAILED_RECOVERABLE;
+        return TaskState.State.SHOULD_RETRY;
       case FAILED_UNRECOVERABLE:
-        return TaskState.State.FAILED_UNRECOVERABLE;
+        return TaskState.State.FAILED;
       case ON_HOLD:
         return ON_HOLD;
       default:
@@ -373,13 +373,13 @@ private void accumulateBarrierMetric(final List<ControlMessage.PartitionSizeEntr
     }
   }
 
-  private TaskState.RecoverableFailureCause convertFailureCause(
+  private TaskState.RecoverableTaskFailureCause convertFailureCause(
       final ControlMessage.RecoverableFailureCause cause) {
     switch (cause) {
       case InputReadFailure:
-        return TaskState.RecoverableFailureCause.INPUT_READ_FAILURE;
+        return TaskState.RecoverableTaskFailureCause.INPUT_READ_FAILURE;
       case OutputWriteFailure:
-        return TaskState.RecoverableFailureCause.OUTPUT_WRITE_FAILURE;
+        return TaskState.RecoverableTaskFailureCause.OUTPUT_WRITE_FAILURE;
       default:
         throw new UnknownFailureCauseException(
             new Throwable("The failure cause for the recoverable failure is unknown"));
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java
index 40094243..880c8d0e 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobScheduler.java
@@ -15,6 +15,7 @@
  */
 package edu.snu.nemo.runtime.master.scheduler;
 
+import com.google.common.collect.Sets;
 import edu.snu.nemo.common.Pair;
 import edu.snu.nemo.common.dag.DAG;
 import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper;
@@ -23,6 +24,7 @@
 import edu.snu.nemo.runtime.common.RuntimeIdGenerator;
 import edu.snu.nemo.runtime.common.eventhandler.DynamicOptimizationEvent;
 import edu.snu.nemo.runtime.common.plan.*;
+import edu.snu.nemo.runtime.common.state.BlockState;
 import edu.snu.nemo.runtime.common.state.TaskState;
 import edu.snu.nemo.runtime.master.eventhandler.UpdatePhysicalPlanEventHandler;
 import edu.snu.nemo.common.exception.*;
@@ -39,12 +41,10 @@
 import javax.inject.Inject;
 import java.util.*;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import org.slf4j.Logger;
 
-import static edu.snu.nemo.runtime.common.state.TaskState.State.ON_HOLD;
-import static edu.snu.nemo.runtime.common.state.TaskState.State.READY;
-
 /**
  * (CONCURRENCY) Only a single dedicated thread should use the public methods of this class.
  * (i.e., runtimeMasterThread in RuntimeMaster)
@@ -74,7 +74,7 @@
    */
   private PhysicalPlan physicalPlan;
   private JobStateManager jobStateManager;
-  private int initialScheduleGroup;
+  private List<List<Stage>> sortedScheduleGroups;
 
   @Inject
   public BatchSingleJobScheduler(final SchedulerRunner schedulerRunner,
@@ -108,16 +108,19 @@ public void scheduleJob(final PhysicalPlan physicalPlanOfJob, final JobStateMana
     this.physicalPlan = physicalPlanOfJob;
     this.jobStateManager = jobStateManagerOfJob;
 
-    schedulerRunner.scheduleJob(jobStateManagerOfJob);
-    schedulerRunner.runSchedulerThread();
-
-    LOG.info("Job to schedule: {}", physicalPlanOfJob.getId());
+    schedulerRunner.run(jobStateManager);
+    LOG.info("Job to schedule: {}", this.physicalPlan.getId());
 
-    this.initialScheduleGroup = physicalPlanOfJob.getStageDAG().getVertices().stream()
-        .mapToInt(stage -> stage.getScheduleGroup())
-        .min().getAsInt();
+    this.sortedScheduleGroups = this.physicalPlan.getStageDAG().getVertices()
+        .stream()
+        .collect(Collectors.groupingBy(Stage::getScheduleGroup))
+        .entrySet()
+        .stream()
+        .sorted(Map.Entry.comparingByKey())
+        .map(Map.Entry::getValue)
+        .collect(Collectors.toList());
 
-    scheduleNextScheduleGroup(initialScheduleGroup);
+    doSchedule();
   }
 
   @Override
@@ -127,6 +130,7 @@ public void updateJob(final String jobId, final PhysicalPlan newPhysicalPlan, fi
     this.physicalPlan = newPhysicalPlan;
     if (taskInfo != null) {
       onTaskExecutionComplete(taskInfo.left(), taskInfo.right(), true);
+      doSchedule();
     }
   }
 
@@ -147,7 +151,7 @@ public void onTaskStateReportFromExecutor(final String executorId,
                                             final int taskAttemptIndex,
                                             final TaskState.State newState,
                                             @Nullable final String vertexPutOnHold,
-                                            final TaskState.RecoverableFailureCause failureCause) {
+                                            final TaskState.RecoverableTaskFailureCause failureCause) {
     final int currentTaskAttemptIndex = jobStateManager.getTaskAttempt(taskId);
 
     if (taskAttemptIndex == currentTaskAttemptIndex) {
@@ -155,15 +159,16 @@ public void onTaskStateReportFromExecutor(final String executorId,
       jobStateManager.onTaskStateChanged(taskId, newState);
       switch (newState) {
         case COMPLETE:
-          onTaskExecutionComplete(executorId, taskId);
+          onTaskExecutionComplete(executorId, taskId, false);
           break;
-        case FAILED_RECOVERABLE:
+        case SHOULD_RETRY:
+          // SHOULD_RETRY from an executor means that the task ran into a recoverable failure
           onTaskExecutionFailedRecoverable(executorId, taskId, failureCause);
           break;
         case ON_HOLD:
           onTaskExecutionOnHold(executorId, taskId, vertexPutOnHold);
           break;
-        case FAILED_UNRECOVERABLE:
+        case FAILED:
           throw new UnrecoverableFailureException(new Exception(new StringBuffer().append("The job failed on Task #")
               .append(taskId).append(" in Executor ").append(executorId).toString()));
         case READY:
@@ -173,11 +178,43 @@ public void onTaskStateReportFromExecutor(final String executorId,
         default:
           throw new UnknownExecutionStateException(new Exception("This TaskState is unknown: " + newState));
       }
+
+      // Invoke doSchedule()
+      switch (newState) {
+        case COMPLETE:
+        case ON_HOLD:
+          // If the stage has completed
+          final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
+          if (jobStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE)) {
+            if (!jobStateManager.isJobDone()) {
+              doSchedule();
+            }
+          }
+          break;
+        case SHOULD_RETRY:
+          // Retry the failed task
+          doSchedule();
+          break;
+        default:
+          break;
+      }
+
+      // Invoke schedulerRunner.onExecutorSlotAvailable()
+      switch (newState) {
+        // These three states mean that a slot is made available.
+        case COMPLETE:
+        case ON_HOLD:
+        case SHOULD_RETRY:
+          schedulerRunner.onExecutorSlotAvailable();
+          break;
+        default:
+          break;
+      }
     } else if (taskAttemptIndex < currentTaskAttemptIndex) {
-      // Do not change state, as this notification is for a previous task attempt.
+      // Do not change state, as this report is from a previous task attempt.
       // For example, the master can receive a notification that an executor has been removed,
       // and then a notification that the task that was running in the removed executor has been completed.
-      // In this case, if we do not consider the attempt number, the state changes from FAILED_RECOVERABLE to COMPLETED,
+      // In this case, if we do not consider the attempt number, the state changes from SHOULD_RETRY to COMPLETE,
       // which is illegal.
       LOG.info("{} state change to {} arrived late, we will ignore this.", new Object[]{taskId, newState});
     } else {
@@ -187,35 +224,33 @@ public void onTaskStateReportFromExecutor(final String executorId,
 
   @Override
   public void onExecutorAdded(final ExecutorRepresenter executorRepresenter) {
+    LOG.info("{} added", executorRepresenter.getExecutorId());
     executorRegistry.registerExecutor(executorRepresenter);
-    schedulerRunner.onAnExecutorAvailable();
+    schedulerRunner.onExecutorSlotAvailable();
   }
 
   @Override
   public void onExecutorRemoved(final String executorId) {
-    final Set<String> tasksToReExecute = new HashSet<>();
-    // Tasks for lost blocks
-    tasksToReExecute.addAll(blockManagerMaster.removeWorker(executorId));
+    blockManagerMaster.removeWorker(executorId);
 
-    // Tasks executing on the removed executor
+    // These are tasks that were running at the time of executor removal.
+    final Set<String> interruptedTasks = new HashSet<>();
     executorRegistry.updateExecutor(executorId, (executor, state) -> {
-      tasksToReExecute.addAll(executor.onExecutorFailed());
+      interruptedTasks.addAll(executor.onExecutorFailed());
       return Pair.of(executor, ExecutorRegistry.ExecutorState.FAILED);
     });
 
-    tasksToReExecute.forEach(failedTaskId -> {
-      final int attemptIndex = jobStateManager.getTaskAttempt(failedTaskId);
-      onTaskStateReportFromExecutor(executorId, failedTaskId, attemptIndex, TaskState.State.FAILED_RECOVERABLE,
-          null, TaskState.RecoverableFailureCause.CONTAINER_FAILURE);
-    });
+    // We need to retry the interrupted tasks, and also recover the tasks' missing input blocks if needed.
+    final Set<String> tasksToReExecute =
+        Sets.union(interruptedTasks, recursivelyGetParentTasksForLostBlocks(interruptedTasks));
 
-    if (!tasksToReExecute.isEmpty()) {
-      // Schedule a stage after marking the necessary tasks to failed_recoverable.
-      // The stage for one of the tasks that failed is a starting point to look
-      // for the next stage to be scheduled.
-      scheduleNextScheduleGroup(getScheduleGroupOfStage(
-          RuntimeIdGenerator.getStageIdFromTaskId(tasksToReExecute.iterator().next())));
-    }
+    // Report SHOULD_RETRY tasks so they can be re-scheduled
+    LOG.info("{} removed: {} will be retried", executorId, tasksToReExecute);
+    tasksToReExecute.forEach(
+        taskToReExecute -> jobStateManager.onTaskStateChanged(taskToReExecute, TaskState.State.SHOULD_RETRY));
+
+    // Trigger the scheduling of SHOULD_RETRY tasks in the earliest scheduleGroup
+    doSchedule();
   }
 
   @Override
@@ -224,108 +259,45 @@ public void terminate() {
     this.executorRegistry.terminate();
   }
 
+  ////////////////////////////////////////////////////////////////////// Key methods for scheduling
+
   /**
-   * Schedules the next schedule group to execute.
-   * @param referenceIndex of the schedule group.
+   * The main entry point for task scheduling.
+   * This operation can be invoked at any point during job execution, as it is designed to be free of side-effects,
+   * and integrate well with {@link PendingTaskCollectionPointer} and {@link SchedulerRunner}.
    */
-  private void scheduleNextScheduleGroup(final int referenceIndex) {
-    final Optional<List<Stage>> nextScheduleGroupToSchedule = selectNextScheduleGroupToSchedule(referenceIndex);
+  private void doSchedule() {
+    final Optional<List<Stage>> earliest = selectEarliestSchedulableGroup();
 
-    if (nextScheduleGroupToSchedule.isPresent()) {
-      LOG.info("Scheduling: ScheduleGroup {}", nextScheduleGroupToSchedule.get());
-      final List<Task> tasksToSchedule = nextScheduleGroupToSchedule.get().stream()
-          .flatMap(stage -> getSchedulableTasks(stage).stream())
+    if (earliest.isPresent()) {
+      // Get schedulable tasks.
+      final List<Task> tasksToSchedule = earliest.get().stream()
+          .flatMap(stage -> selectSchedulableTasks(stage).stream())
           .collect(Collectors.toList());
+
+      LOG.info("Attempting to schedule {} in the same ScheduleGroup",
+          tasksToSchedule.stream().map(Task::getTaskId).collect(Collectors.toList()));
+
+      // Set the pointer to the schedulable tasks.
       pendingTaskCollectionPointer.setToOverwrite(tasksToSchedule);
+
+      // Notify the runner that a new collection is available.
       schedulerRunner.onNewPendingTaskCollectionAvailable();
     } else {
-      LOG.info("Skipping this round as the next schedulable stages have already been scheduled.");
+      LOG.info("Skipping this round as no ScheduleGroup is schedulable.");
     }
   }
 
-  /**
-   * Selects the next stage to schedule.
-   * It takes the referenceScheduleGroup as a reference point to begin looking for the stages to execute:
-   *
-   * a) returns the failed_recoverable stage(s) of the earliest schedule group, if it(they) exists.
-   * b) returns an empty optional if there are no schedulable stages at the moment.
-   *    - if the current schedule group is still executing
-   *    - if an ancestor schedule group is still executing
-   * c) returns the next set of schedulable stages (if the current schedule group has completed execution)
-   *
-   * @param referenceScheduleGroup
-   *      the index of the schedule group that is executing/has executed when this method is called.
-   * @return an optional of the (possibly empty) next schedulable stage
-   */
-  private Optional<List<Stage>> selectNextScheduleGroupToSchedule(final int referenceScheduleGroup) {
-    // Recursively check the previous schedule group.
-    if (referenceScheduleGroup > initialScheduleGroup) {
-      final Optional<List<Stage>> ancestorStagesFromAScheduleGroup =
-          selectNextScheduleGroupToSchedule(referenceScheduleGroup - 1);
-      if (ancestorStagesFromAScheduleGroup.isPresent()) {
-        // Nothing to schedule from the previous schedule group.
-        return ancestorStagesFromAScheduleGroup;
-      }
-    }
-
-    // Return the schedulable stage list in reverse-topological order
-    // since the stages that belong to the same schedule group are mutually independent,
-    // or connected by a "push" edge, where scheduling the children stages first is preferred.
-    final List<Stage> reverseTopoStages = physicalPlan.getStageDAG().getTopologicalSort();
-    Collections.reverse(reverseTopoStages);
-
-    // All previous schedule groups are complete, we need to check for the current schedule group.
-    final List<Stage> currentScheduleGroup = reverseTopoStages
-        .stream()
-        .filter(stage -> stage.getScheduleGroup() == referenceScheduleGroup)
-        .collect(Collectors.toList());
-    final boolean allStagesOfThisGroupComplete = currentScheduleGroup
-        .stream()
-        .map(Stage::getId)
-        .map(jobStateManager::getStageState)
-        .allMatch(state -> state.equals(StageState.State.COMPLETE));
-
-    if (!allStagesOfThisGroupComplete) {
-      LOG.info("There are remaining stages in the current schedule group, {}", referenceScheduleGroup);
-      final List<Stage> stagesToSchedule = currentScheduleGroup
-          .stream()
-          .filter(stage -> {
-            final StageState.State stageState = jobStateManager.getStageState(stage.getId());
-            return stageState.equals(StageState.State.FAILED_RECOVERABLE)
-                || stageState.equals(StageState.State.READY);
-          })
-          .collect(Collectors.toList());
-      return (stagesToSchedule.isEmpty())
-          ? Optional.empty()
-          : Optional.of(stagesToSchedule);
-    } else {
-      // By the time the control flow has reached here,
-      // we are ready to move onto the next ScheduleGroup
-      final List<Stage> stagesToSchedule = reverseTopoStages
-          .stream()
-          .filter(stage -> {
-            if (stage.getScheduleGroup() == referenceScheduleGroup + 1) {
-              final String stageId = stage.getId();
-              return jobStateManager.getStageState(stageId) != StageState.State.EXECUTING
-                  && jobStateManager.getStageState(stageId) != StageState.State.COMPLETE;
-            }
-            return false;
-          })
-          .collect(Collectors.toList());
-
-      if (stagesToSchedule.isEmpty()) {
-        LOG.debug("ScheduleGroup {}: already executing/complete!, so we skip this", referenceScheduleGroup + 1);
-        return Optional.empty();
-      }
-
-      return Optional.of(stagesToSchedule);
-    }
+  private Optional<List<Stage>> selectEarliestSchedulableGroup() {
+    return sortedScheduleGroups.stream()
+        .filter(scheduleGroup -> scheduleGroup.stream()
+            .map(Stage::getId)
+            .map(jobStateManager::getStageState)
+            .anyMatch(state -> state.equals(StageState.State.INCOMPLETE))) // any incomplete stage in the group
+        .findFirst(); // selects the one with the smallest scheduling group index.
   }
 
-  /**
-   * @param stageToSchedule the stage to schedule.
-   */
-  private List<Task> getSchedulableTasks(final Stage stageToSchedule) {
+  private List<Task> selectSchedulableTasks(final Stage stageToSchedule) {
     final List<StageEdge> stageIncomingEdges =
         physicalPlan.getStageDAG().getIncomingEdgesOf(stageToSchedule.getId());
     final List<StageEdge> stageOutgoingEdges =
@@ -333,39 +305,35 @@ private void scheduleNextScheduleGroup(final int referenceIndex) {
 
     final List<String> taskIdsToSchedule = new LinkedList<>();
     for (final String taskId : stageToSchedule.getTaskIds()) {
-      // this happens when the belonging stage's other tasks have failed recoverable,
-      // but this task's results are safe.
       final TaskState.State taskState = jobStateManager.getTaskState(taskId);
 
       switch (taskState) {
+        // Don't schedule these.
         case COMPLETE:
         case EXECUTING:
-          LOG.info("Skipping {} because its outputs are safe!", taskId);
+        case ON_HOLD:
           break;
-        case FAILED_RECOVERABLE:
-          jobStateManager.onTaskStateChanged(taskId, READY);
+
+        // These are schedulable.
+        case SHOULD_RETRY:
+          jobStateManager.onTaskStateChanged(taskId, TaskState.State.READY);
         case READY:
           taskIdsToSchedule.add(taskId);
           break;
-        case ON_HOLD:
-          // Do nothing
-          break;
+
+        // This shouldn't happen.
         default:
-          throw new SchedulingException(new Throwable("Detected a FAILED_UNRECOVERABLE Task"));
+          throw new SchedulingException(new Throwable("Detected a FAILED Task"));
       }
     }
 
-    LOG.info("Scheduling Stage {}", stageToSchedule.getId());
-
-    // each readable and source task will be bounded in executor.
+    // Create and return tasks.
     final List<Map<String, Readable>> vertexIdToReadables = stageToSchedule.getVertexIdToReadables();
-
     final List<Task> tasks = new ArrayList<>(taskIdsToSchedule.size());
     taskIdsToSchedule.forEach(taskId -> {
-      blockManagerMaster.onProducerTaskScheduled(taskId);
+      blockManagerMaster.onProducerTaskScheduled(taskId); // Notify the block manager early for push edges.
       final int taskIdx = RuntimeIdGenerator.getIndexFromTaskId(taskId);
       final int attemptIdx = jobStateManager.getTaskAttempt(taskId);
-
       tasks.add(new Task(
           physicalPlan.getId(),
           taskId,
@@ -379,38 +347,7 @@ private void scheduleNextScheduleGroup(final int referenceIndex) {
     return tasks;
   }
 
-  /**
-   * @param taskId id of the task
-   * @return the IR dag
-   */
-  private DAG<IRVertex, RuntimeEdge<IRVertex>> getVertexDagById(final String taskId) {
-    for (final Stage stage : physicalPlan.getStageDAG().getVertices()) {
-      if (stage.getId().equals(RuntimeIdGenerator.getStageIdFromTaskId(taskId))) {
-        return stage.getIRDAG();
-      }
-    }
-    throw new RuntimeException(new Throwable("This taskId does not exist in the plan"));
-  }
-
-  private Stage getStageById(final String stageId) {
-    for (final Stage stage : physicalPlan.getStageDAG().getVertices()) {
-      if (stage.getId().equals(stageId)) {
-        return stage;
-      }
-    }
-    throw new RuntimeException(new Throwable("This taskId does not exist in the plan"));
-  }
-
-  /**
-   * Action after task execution has been completed, not after it has been put on hold.
-   *
-   * @param executorId  the ID of the executor.
-   * @param taskId the ID pf the task completed.
-   */
-  private void onTaskExecutionComplete(final String executorId,
-                                       final String taskId) {
-    onTaskExecutionComplete(executorId, taskId, false);
-  }
+  ////////////////////////////////////////////////////////////////////// Task state change handlers
 
   /**
    * Action after task execution has been completed.
@@ -420,7 +357,7 @@ private void onTaskExecutionComplete(final String executorId,
    */
   private void onTaskExecutionComplete(final String executorId,
                                        final String taskId,
-                                       final Boolean isOnHoldToComplete) {
+                                       final boolean isOnHoldToComplete) {
     LOG.debug("{} completed in {}", new Object[]{taskId, executorId});
     if (!isOnHoldToComplete) {
       executorRegistry.updateExecutor(executorId, (executor, state) -> {
@@ -428,15 +365,6 @@ private void onTaskExecutionComplete(final String executorId,
         return Pair.of(executor, state);
       });
     }
-
-    final String stageIdForTaskUponCompletion = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
-    if (jobStateManager.getStageState(stageIdForTaskUponCompletion).equals(StageState.State.COMPLETE)) {
-      // if the stage this task belongs to is complete,
-      if (!jobStateManager.isJobDone()) {
-        scheduleNextScheduleGroup(getScheduleGroupOfStage(stageIdForTaskUponCompletion));
-      }
-    }
-    schedulerRunner.onAnExecutorAvailable();
   }
 
   /**
@@ -466,7 +394,7 @@ private void onTaskExecutionOnHold(final String executorId,
               .filter(irVertex -> irVertex instanceof MetricCollectionBarrierVertex)
               .distinct()
               .map(irVertex -> (MetricCollectionBarrierVertex) irVertex) // convert types
-              .findFirst().orElseThrow(() -> new RuntimeException(ON_HOLD.name() // get it
+              .findFirst().orElseThrow(() -> new RuntimeException(TaskState.State.ON_HOLD.name() // get it
               + " called with failed task ids by some other task than "
               + MetricCollectionBarrierVertex.class.getSimpleName()));
       // and we will use this vertex to perform metric collection and dynamic optimization.
@@ -476,7 +404,6 @@ private void onTaskExecutionOnHold(final String executorId,
     } else {
       onTaskExecutionComplete(executorId, taskId, true);
     }
-    schedulerRunner.onAnExecutorAvailable();
   }
 
   /**
@@ -487,33 +414,75 @@ private void onTaskExecutionOnHold(final String executorId,
    */
   private void onTaskExecutionFailedRecoverable(final String executorId,
                                                 final String taskId,
-                                                final TaskState.RecoverableFailureCause failureCause) {
+                                                final TaskState.RecoverableTaskFailureCause failureCause) {
     LOG.info("{} failed in {} by {}", taskId, executorId, failureCause);
     executorRegistry.updateExecutor(executorId, (executor, state) -> {
       executor.onTaskExecutionFailed(taskId);
       return Pair.of(executor, state);
     });
 
-    final String stageId = RuntimeIdGenerator.getStageIdFromTaskId(taskId);
-
     switch (failureCause) {
       // Previous task must be re-executed, and incomplete tasks of the belonging stage must be rescheduled.
       case INPUT_READ_FAILURE:
-        // TODO #50: Carefully retry tasks in the scheduler
+        // TODO #54: Handle remote data fetch failures
       case OUTPUT_WRITE_FAILURE:
         blockManagerMaster.onProducerTaskFailed(taskId);
-        scheduleNextScheduleGroup(getScheduleGroupOfStage(stageId));
-        break;
-      case CONTAINER_FAILURE:
-        LOG.info("Only the failed task will be retried.");
         break;
       default:
         throw new UnknownFailureCauseException(new Throwable("Unknown cause: " + failureCause));
     }
-    schedulerRunner.onAnExecutorAvailable();
   }
 
-  private int getScheduleGroupOfStage(final String stageId) {
-    return physicalPlan.getStageDAG().getVertexById(stageId).getScheduleGroup();
+  ////////////////////////////////////////////////////////////////////// Helper methods
+
+  private Set<String> recursivelyGetParentTasksForLostBlocks(final Set<String> children) {
+    if (children.isEmpty()) {
+      return Collections.emptySet();
+    }
+
+    final Set<String> selectedParentTasks = children.stream()
+        .flatMap(child -> getParentTasks(child).stream())
+        .filter(parent -> blockManagerMaster.getIdsOfBlocksProducedBy(parent).stream()
+            .map(blockManagerMaster::getBlockState)
+            .anyMatch(blockState -> blockState.equals(BlockState.State.NOT_AVAILABLE)) // If a block is missing
+        )
+        .collect(Collectors.toSet());
+
+    // Recursive call
+    return Sets.union(selectedParentTasks, recursivelyGetParentTasksForLostBlocks(selectedParentTasks));
+  }
+
+  private Set<String> getParentTasks(final String childTaskId) {
+    final String stageIdOfChildTask = RuntimeIdGenerator.getStageIdFromTaskId(childTaskId);
+    return physicalPlan.getStageDAG().getIncomingEdgesOf(stageIdOfChildTask)
+        .stream()
+        .flatMap(inStageEdge -> {
+          final List<String> tasksOfParentStage = inStageEdge.getSrc().getTaskIds();
+          switch (inStageEdge.getDataCommunicationPattern()) {
+            case Shuffle:
+            case BroadCast:
+              // All of the parent stage's tasks are parents
+              return tasksOfParentStage.stream();
+            case OneToOne:
+              // Only one of the parent stage's tasks is a parent
+              return Stream.of(tasksOfParentStage.get(RuntimeIdGenerator.getIndexFromTaskId(childTaskId)));
+            default:
+              throw new IllegalStateException(inStageEdge.toString());
+          }
+        })
+        .collect(Collectors.toSet());
+  }
+
+  /**
+   * @param taskId id of the task
+   * @return the IR dag
+   */
+  private DAG<IRVertex, RuntimeEdge<IRVertex>> getVertexDagById(final String taskId) {
+    for (final Stage stage : physicalPlan.getStageDAG().getVertices()) {
+      if (stage.getId().equals(RuntimeIdGenerator.getStageIdFromTaskId(taskId))) {
+        return stage.getIRDAG();
+      }
+    }
+    throw new RuntimeException("This taskId does not exist in the plan");
   }
 }
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java
index 12114a75..28b52db7 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/Scheduler.java
@@ -79,7 +79,7 @@ void onTaskStateReportFromExecutor(String executorId,
                                      int attemptIdx,
                                      TaskState.State newState,
                                      @Nullable String taskPutOnHold,
-                                     TaskState.RecoverableFailureCause failureCause);
+                                     TaskState.RecoverableTaskFailureCause failureCause);
 
   /**
    * To be called when a job should be terminated.
diff --git a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java
index 42e9a2c4..caf0d40a 100644
--- a/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java
+++ b/runtime/master/src/main/java/edu/snu/nemo/runtime/master/scheduler/SchedulerRunner.java
@@ -148,9 +148,9 @@ void doScheduleTaskList() {
   }
 
   /**
-   * Signals to the condition on executor availability.
+   * Signals to the condition on executor slot availability.
    */
-  void onAnExecutorAvailable() {
+  void onExecutorSlotAvailable() {
     schedulingIteration.signal();
   }
 
@@ -164,24 +164,15 @@ void onNewPendingTaskCollectionAvailable() {
   /**
    * Run the scheduler thread.
    */
-  void runSchedulerThread() {
+  void run(final JobStateManager jobStateManager) {
     if (!isTerminated && !isSchedulerRunning) {
+      jobStateManagers.put(jobStateManager.getJobId(), jobStateManager);
       schedulerThread.execute(new SchedulerThread());
       schedulerThread.shutdown();
       isSchedulerRunning = true;
     }
   }
 
-  /**
-   * Begin scheduling a job.
-   * @param jobStateManager the corresponding {@link JobStateManager}
-   */
-  void scheduleJob(final JobStateManager jobStateManager) {
-    if (!isTerminated) {
-      jobStateManagers.put(jobStateManager.getJobId(), jobStateManager);
-    } // else ignore new incoming jobs when terminated.
-  }
-
   void terminate() {
     isTerminated = true;
     schedulingIteration.signal();
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java
index 4aa2b0ac..68608938 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/JobStateManagerTest.java
@@ -76,7 +76,7 @@ public void testPhysicalPlanStateChanges() throws Exception {
     final PhysicalPlan physicalPlan =
         TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
     final JobStateManager jobStateManager =
-        new JobStateManager(physicalPlan, blockManagerMaster, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+        new JobStateManager(physicalPlan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
 
     assertEquals(jobStateManager.getJobId(), "TestPlan");
 
@@ -108,7 +108,7 @@ public void testWaitUntilFinish() throws Exception {
     final PhysicalPlan physicalPlan =
         TestPlanGenerator.generatePhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined, false);
     final JobStateManager jobStateManager =
-        new JobStateManager(physicalPlan, blockManagerMaster, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+        new JobStateManager(physicalPlan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
 
     assertFalse(jobStateManager.isJobDone());
 
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java
index 8bfe43d2..e4491b67 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/BatchSingleJobSchedulerTest.java
@@ -148,23 +148,16 @@ public void testPush() throws Exception {
   }
 
   private void scheduleAndCheckJobTermination(final PhysicalPlan plan) throws InjectionException {
-    final JobStateManager jobStateManager = new JobStateManager(plan, blockManagerMaster, metricMessageHandler, 1);
+    final JobStateManager jobStateManager = new JobStateManager(plan, metricMessageHandler, 1);
     scheduler.scheduleJob(plan, jobStateManager);
 
-    // For each ScheduleGroup, test:
-    // a) all stages in the ScheduleGroup enters the executing state
-    // b) the stages of the next ScheduleGroup are scheduled after the stages of each ScheduleGroup are made "complete".
+    // For each ScheduleGroup, test if the tasks of the next ScheduleGroup are scheduled
+    // after the stages of each ScheduleGroup are made "complete".
     for (int i = 0; i < getNumScheduleGroups(plan.getStageDAG()); i++) {
       final int scheduleGroupIdx = i;
       final List<Stage> stages = filterStagesWithAScheduleGroup(plan.getStageDAG(), scheduleGroupIdx);
 
       LOG.debug("Checking that all stages of ScheduleGroup {} enter the executing state", scheduleGroupIdx);
-      stages.forEach(stage -> {
-        while (jobStateManager.getStageState(stage.getId()) != StageState.State.EXECUTING) {
-
-        }
-      });
-
       stages.forEach(stage -> {
         SchedulerTestUtil.completeStage(
             jobStateManager, scheduler, executorRegistry, stage, SCHEDULE_ATTEMPT_INDEX);
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
index 387c6b90..815dff8f 100644
--- a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/SchedulerTestUtil.java
@@ -45,7 +45,7 @@ static void completeStage(final JobStateManager jobStateManager,
       if (StageState.State.COMPLETE == stageState) {
         // Stage has completed, so we break out of the loop.
         break;
-      } else if (StageState.State.EXECUTING == stageState) {
+      } else if (StageState.State.INCOMPLETE == stageState) {
         stage.getTaskIds().forEach(taskId -> {
           final TaskState.State taskState = jobStateManager.getTaskState(taskId);
           if (TaskState.State.EXECUTING == taskState) {
@@ -57,8 +57,6 @@ static void completeStage(final JobStateManager jobStateManager,
             throw new IllegalStateException(taskState.toString());
           }
         });
-      } else if (StageState.State.READY == stageState) {
-        // Skip and retry in the next loop.
       } else {
         throw new IllegalStateException(stageState.toString());
       }
@@ -79,7 +77,7 @@ static void sendTaskStateEventToScheduler(final Scheduler scheduler,
                                             final String taskId,
                                             final TaskState.State newState,
                                             final int attemptIdx,
-                                            final TaskState.RecoverableFailureCause cause) {
+                                            final TaskState.RecoverableTaskFailureCause cause) {
     final ExecutorRepresenter scheduledExecutor;
     while (true) {
       final Optional<ExecutorRepresenter> optional = executorRegistry.findExecutorForTask(taskId);
@@ -91,4 +89,12 @@ static void sendTaskStateEventToScheduler(final Scheduler scheduler,
     scheduler.onTaskStateReportFromExecutor(scheduledExecutor.getExecutorId(), taskId, attemptIdx,
         newState, null, cause);
   }
+
+  static void sendTaskStateEventToScheduler(final Scheduler scheduler,
+                                            final ExecutorRegistry executorRegistry,
+                                            final String taskId,
+                                            final TaskState.State newState,
+                                            final int attemptIdx) {
+    sendTaskStateEventToScheduler(scheduler, executorRegistry, taskId, newState, attemptIdx, null);
+  }
 }
diff --git a/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRestartTest.java b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRestartTest.java
new file mode 100644
index 00000000..7064a23e
--- /dev/null
+++ b/runtime/master/src/test/java/edu/snu/nemo/runtime/master/scheduler/TaskRestartTest.java
@@ -0,0 +1,225 @@
+/*
+ * Copyright (C) 2018 Seoul National University
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package edu.snu.nemo.runtime.master.scheduler;
+
+import edu.snu.nemo.common.eventhandler.PubSubEventHandlerWrapper;
+import edu.snu.nemo.common.ir.vertex.executionproperty.ExecutorPlacementProperty;
+import edu.snu.nemo.runtime.common.comm.ControlMessage;
+import edu.snu.nemo.runtime.common.message.MessageSender;
+import edu.snu.nemo.runtime.common.plan.PhysicalPlan;
+import edu.snu.nemo.runtime.common.state.JobState;
+import edu.snu.nemo.runtime.common.state.TaskState;
+import edu.snu.nemo.runtime.master.BlockManagerMaster;
+import edu.snu.nemo.runtime.master.JobStateManager;
+import edu.snu.nemo.runtime.master.MetricMessageHandler;
+import edu.snu.nemo.runtime.master.eventhandler.UpdatePhysicalPlanEventHandler;
+import edu.snu.nemo.runtime.master.resource.ExecutorRepresenter;
+import edu.snu.nemo.runtime.master.resource.ResourceSpecification;
+import edu.snu.nemo.runtime.plangenerator.TestPlanGenerator;
+import org.apache.reef.driver.context.ActiveContext;
+import org.apache.reef.tang.Injector;
+import org.apache.reef.tang.Tang;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TestName;
+import org.junit.runner.RunWith;
+import org.mockito.Mockito;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.*;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests fault tolerance.
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest({BlockManagerMaster.class, SchedulerRunner.class, SchedulingConstraintRegistry.class,
+    PubSubEventHandlerWrapper.class, UpdatePhysicalPlanEventHandler.class, MetricMessageHandler.class})
+public final class TaskRestartTest {
+  @Rule public TestName testName = new TestName();
+
+  private static final Logger LOG = LoggerFactory.getLogger(TaskRestartTest.class.getName());
+  private static final AtomicInteger ID_OFFSET = new AtomicInteger(1);
+
+  private Random random;
+  private Scheduler scheduler;
+  private ExecutorRegistry executorRegistry;
+  private JobStateManager jobStateManager;
+
+  private static final int MAX_SCHEDULE_ATTEMPT = Integer.MAX_VALUE;
+
+  @Before
+  public void setUp() throws Exception {
+    // To understand which part of the log belongs to which test
+    LOG.info("===== Testing {} =====", testName.getMethodName());
+    final Injector injector = Tang.Factory.getTang().newInjector();
+
+    // Get random
+    random = new Random(0); // Fixed seed for reproducing test results.
+
+    // Get executorRegistry
+    executorRegistry = injector.getInstance(ExecutorRegistry.class);
+
+    // Get scheduler
+    final PubSubEventHandlerWrapper pubSubEventHandler = mock(PubSubEventHandlerWrapper.class);
+    final UpdatePhysicalPlanEventHandler updatePhysicalPlanEventHandler = mock(UpdatePhysicalPlanEventHandler.class);
+    final SchedulingConstraintRegistry constraintRegistry = mock(SchedulingConstraintRegistry.class);
+    final SchedulingPolicy schedulingPolicy = injector.getInstance(MinOccupancyFirstSchedulingPolicy.class);
+    final PendingTaskCollectionPointer pendingTaskCollectionPointer = new PendingTaskCollectionPointer();
+    final SchedulerRunner schedulerRunner = new SchedulerRunner(
+        constraintRegistry, schedulingPolicy, pendingTaskCollectionPointer, executorRegistry);
+    final BlockManagerMaster blockManagerMaster = mock(BlockManagerMaster.class);
+    scheduler =  new BatchSingleJobScheduler(schedulerRunner, pendingTaskCollectionPointer, blockManagerMaster,
+        pubSubEventHandler, updatePhysicalPlanEventHandler, executorRegistry);
+
+    // Get JobStateManager
+    jobStateManager = runPhysicalPlan(TestPlanGenerator.PlanType.TwoVerticesJoined);
+  }
+
+  @Test(timeout=7000)
+  public void testExecutorRemoved() throws Exception {
+    // Until the job finishes, events happen
+    while (!jobStateManager.isJobDone()) {
+      // 50% chance an executor is removed, 50% chance one is added, 80% chance a task is completed
+      executorRemoved(0.5);
+      executorAdded(0.5);
+      taskCompleted(0.8);
+
+      // 10ms sleep
+      Thread.sleep(10);
+    }
+
+    // Job should COMPLETE
+    assertEquals(JobState.State.COMPLETE, jobStateManager.getJobState());
+    assertTrue(jobStateManager.isJobDone());
+  }
+
+  @Test(timeout=7000)
+  public void testTaskOutputWriteFailure() throws Exception {
+    // Three executors are used
+    executorAdded(1.0);
+    executorAdded(1.0);
+    executorAdded(1.0);
+
+    // Until the job finishes, events happen
+    while (!jobStateManager.isJobDone()) {
+      // 50% chance task completed
+      // 50% chance task output write failed
+      taskCompleted(0.5);
+      taskOutputWriteFailed(0.5);
+
+      // 10ms sleep
+      Thread.sleep(10);
+    }
+
+    // Job should COMPLETE
+    assertEquals(JobState.State.COMPLETE, jobStateManager.getJobState());
+    assertTrue(jobStateManager.isJobDone());
+  }
+
+  ////////////////////////////////////////////////////////////////// Events
+
+  private void executorAdded(final double chance) {
+    if (random.nextDouble() > chance) {
+      return;
+    }
+
+    final MessageSender<ControlMessage.Message> mockMsgSender = mock(MessageSender.class);
+    final ActiveContext activeContext = mock(ActiveContext.class);
+    Mockito.doThrow(new RuntimeException()).when(activeContext).close();
+    final ExecutorService serExecutorService = Executors.newSingleThreadExecutor();
+    final ResourceSpecification computeSpec = new ResourceSpecification(ExecutorPlacementProperty.COMPUTE, 2, 0);
+    final ExecutorRepresenter executor = new ExecutorRepresenter("EXECUTOR" + ID_OFFSET.getAndIncrement(),
+        computeSpec, mockMsgSender, activeContext, serExecutorService, "NODE" + ID_OFFSET.getAndIncrement());
+    scheduler.onExecutorAdded(executor);
+  }
+
+  private void executorRemoved(final double chance) {
+    if (random.nextDouble() > chance) {
+      return;
+    }
+
+    executorRegistry.viewExecutors(executors -> {
+      if (executors.isEmpty()) {
+        return;
+      }
+
+      final List<ExecutorRepresenter> executorList = new ArrayList<>(executors);
+      final int randomIndex = random.nextInt(executorList.size());
+
+      // Because synchronized blocks are reentrant and there's no additional operation after this point,
+      // we can call scheduler.onExecutorRemoved() while still inside executorRegistry.viewExecutors()
+      scheduler.onExecutorRemoved(executorList.get(randomIndex).getExecutorId());
+    });
+  }
+
+  private void taskCompleted(final double chance) {
+    if (random.nextDouble() > chance) {
+      return;
+    }
+
+    final List<String> executingTasks = getTasksInState(jobStateManager, TaskState.State.EXECUTING);
+    if (!executingTasks.isEmpty()) {
+      final int randomIndex = random.nextInt(executingTasks.size());
+      final String selectedTask = executingTasks.get(randomIndex);
+      SchedulerTestUtil.sendTaskStateEventToScheduler(scheduler, executorRegistry, selectedTask,
+          TaskState.State.COMPLETE, jobStateManager.getTaskAttempt(selectedTask));
+    }
+  }
+
+  private void taskOutputWriteFailed(final double chance) {
+    if (random.nextDouble() > chance) {
+      return;
+    }
+
+    final List<String> executingTasks = getTasksInState(jobStateManager, TaskState.State.EXECUTING);
+    if (!executingTasks.isEmpty()) {
+      final int randomIndex = random.nextInt(executingTasks.size());
+      final String selectedTask = executingTasks.get(randomIndex);
+      SchedulerTestUtil.sendTaskStateEventToScheduler(scheduler, executorRegistry, selectedTask,
+          TaskState.State.SHOULD_RETRY, jobStateManager.getTaskAttempt(selectedTask),
+          TaskState.RecoverableTaskFailureCause.OUTPUT_WRITE_FAILURE);
+    }
+  }
+
+  ////////////////////////////////////////////////////////////////// Helper methods
+
+  private List<String> getTasksInState(final JobStateManager jobStateManager, final TaskState.State state) {
+    return jobStateManager.getAllTaskStates().entrySet().stream()
+        .filter(entry -> entry.getValue().getStateMachine().getCurrentState().equals(state))
+        .map(Map.Entry::getKey)
+        .collect(Collectors.toList());
+  }
+
+  private JobStateManager runPhysicalPlan(final TestPlanGenerator.PlanType planType) throws Exception {
+    final MetricMessageHandler metricMessageHandler = mock(MetricMessageHandler.class);
+    final PhysicalPlan plan = TestPlanGenerator.generatePhysicalPlan(planType, false);
+    final JobStateManager jobStateManager = new JobStateManager(plan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+    scheduler.scheduleJob(plan, jobStateManager);
+    return jobStateManager;
+  }
+}
diff --git a/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java b/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java
index c962ad33..adce9e66 100644
--- a/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java
+++ b/tests/src/test/java/edu/snu/nemo/tests/client/ClientEndpointTest.java
@@ -81,9 +81,8 @@ public void testState() throws Exception {
         new LocalMessageEnvironment(MessageEnvironment.MASTER_COMMUNICATION_ID, messageDispatcher);
     final Injector injector = Tang.Factory.getTang().newInjector();
     injector.bindVolatileInstance(MessageEnvironment.class, messageEnvironment);
-    final BlockManagerMaster pmm = injector.getInstance(BlockManagerMaster.class);
     final JobStateManager jobStateManager =
-        new JobStateManager(physicalPlan, pmm, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
+        new JobStateManager(physicalPlan, metricMessageHandler, MAX_SCHEDULE_ATTEMPT);
 
     final DriverEndpoint driverEndpoint = new DriverEndpoint(jobStateManager, clientEndpoint);
 


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services