You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by gu...@apache.org on 2021/02/26 06:21:30 UTC

[flink] branch master updated (1939e9b -> e084dab)

This is an automated email from the ASF dual-hosted git repository.

guoweima pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 1939e9b  [FLINK-21114][python] Add ReducingState and corresponding StateDescriptor for Python DataStream API
     new 93c0677  [hotfix] Fix CheckpointCoordinatorTest to be consistent with the current implementation
     new e084dab  [FLINK-21067][runtime][checkpoint] Modify the logic of computing which tasks to trigger/ack/commit to support finished tasks

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../runtime/checkpoint/CheckpointCoordinator.java  |  117 +-
 .../flink/runtime/checkpoint/CheckpointPlan.java   |   31 +-
 .../checkpoint/CheckpointPlanCalculator.java       |  109 +-
 .../CheckpointPlanCalculatorContext.java           |   25 +-
 .../DefaultCheckpointPlanCalculator.java           |  342 ++++++
 .../ExecutionAttemptMappingProvider.java           |    8 +-
 .../runtime/checkpoint/PendingCheckpoint.java      |    8 +-
 .../runtime/checkpoint/SubtaskStateStats.java      |    4 +
 .../flink/runtime/executiongraph/Execution.java    |    2 +-
 .../runtime/executiongraph/ExecutionGraph.java     |   40 +-
 ...cutionGraphCheckpointPlanCalculatorContext.java |   35 +-
 .../CheckpointCoordinatorFailureTest.java          |   22 +-
 .../CheckpointCoordinatorMasterHooksTest.java      |   96 +-
 .../CheckpointCoordinatorRestoringTest.java        |  265 ++--
 .../checkpoint/CheckpointCoordinatorTest.java      | 1293 +++++++++++---------
 .../CheckpointCoordinatorTestingUtils.java         |  524 ++++----
 .../CheckpointCoordinatorTriggeringTest.java       |  256 ++--
 .../checkpoint/CheckpointStateRestoreTest.java     |  110 +-
 .../checkpoint/CheckpointStatsTrackerTest.java     |   67 +-
 .../DefaultCheckpointPlanCalculatorTest.java       |  429 +++++++
 .../FailoverStrategyCheckpointCoordinatorTest.java |   53 +-
 .../runtime/checkpoint/PendingCheckpointTest.java  |   19 +-
 22 files changed, 2335 insertions(+), 1520 deletions(-)
 copy flink-metrics/flink-metrics-core/src/main/java/org/apache/flink/metrics/CharacterFilter.java => flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculatorContext.java (62%)
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculator.java
 copy flink-formats/flink-parquet/src/main/java/org/apache/flink/formats/parquet/utils/RowMaterializer.java => flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphCheckpointPlanCalculatorContext.java (52%)
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculatorTest.java


[flink] 01/02: [hotfix] Fix CheckpointCoordinatorTest to be consistent with the current implementation

Posted by gu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

guoweima pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 93c06770aa8553b5e4a387591fd4bed2efaf44a2
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Wed Feb 3 14:59:50 2021 +0800

    [hotfix] Fix CheckpointCoordinatorTest to be consistent with the current implementation
    
    Previously, all three of these tests contained exactly the same code:
     - testCheckpointAbortsIfTriggerTasksAreNotExecuted
     - testCheckpointAbortsIfTriggerTasksAreFinished
     - testCheckpointAbortsIfAckTasksAreNotExecuted
    This removes testCheckpointAbortsIfAckTasksAreNotExecuted and changes the first
    two tests to cover different states: the first test has its trigger tasks in
    the CREATED state and the second test has them in the FINISHED state, so that
    each test matches its name.
---
 .../checkpoint/CheckpointCoordinatorTest.java      | 44 ++++++----------------
 1 file changed, 11 insertions(+), 33 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index dcf065a..03dbd55 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -260,7 +260,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
     public void testScheduleTriggerRequestDuringShutdown() throws Exception {
         ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
         CheckpointCoordinator coordinator =
-                getCheckpointCoordinator(new ScheduledExecutorServiceAdapter(executor));
+                getCheckpointCoordinator(
+                        new ScheduledExecutorServiceAdapter(executor), ExecutionState.RUNNING);
         coordinator.shutdown();
         executor.shutdownNow();
         coordinator.scheduleTriggerRequest(); // shouldn't fail
@@ -319,7 +320,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
         try {
 
             // set up the coordinator and validate the initial state
-            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator();
+            CheckpointCoordinator checkpointCoordinator =
+                    getCheckpointCoordinator(ExecutionState.CREATED);
 
             // nothing should be happening
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -345,33 +347,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testCheckpointAbortsIfTriggerTasksAreFinished() {
         try {
-            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator();
-
-            // nothing should be happening
-            assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
-            assertEquals(0, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
-
-            // trigger the first checkpoint. this should not succeed
-            final CompletableFuture<CompletedCheckpoint> checkpointFuture =
-                    checkpointCoordinator.triggerCheckpoint(false);
-            manuallyTriggeredScheduledExecutor.triggerAll();
-            assertTrue(checkpointFuture.isCompletedExceptionally());
-
-            // still, nothing should be happening
-            assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
-            assertEquals(0, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
-
-            checkpointCoordinator.shutdown();
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
-    }
-
-    @Test
-    public void testCheckpointAbortsIfAckTasksAreNotExecuted() {
-        try {
-            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator();
+            CheckpointCoordinator checkpointCoordinator =
+                    getCheckpointCoordinator(ExecutionState.FINISHED);
 
             // nothing should be happening
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -3324,11 +3301,12 @@ public class CheckpointCoordinatorTest extends TestLogger {
                 .build();
     }
 
-    private CheckpointCoordinator getCheckpointCoordinator() {
-        return getCheckpointCoordinator(manuallyTriggeredScheduledExecutor);
+    private CheckpointCoordinator getCheckpointCoordinator(ExecutionState triggerVertexState) {
+        return getCheckpointCoordinator(manuallyTriggeredScheduledExecutor, triggerVertexState);
     }
 
-    private CheckpointCoordinator getCheckpointCoordinator(ScheduledExecutor timer) {
+    private CheckpointCoordinator getCheckpointCoordinator(
+            ScheduledExecutor timer, ExecutionState triggerVertexState) {
         final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID();
         final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID();
         ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1);
@@ -3340,7 +3318,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                         Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID2)),
                         1,
                         1,
-                        ExecutionState.FINISHED);
+                        triggerVertexState);
 
         // create some mock Execution vertices that need to ack the checkpoint
         final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();


[flink] 02/02: [FLINK-21067][runtime][checkpoint] Modify the logic of computing which tasks to trigger/ack/commit to support finished tasks

Posted by gu...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

guoweima pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit e084dabe6c1722dc1ced7d8fddb3c82e7af7b103
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Sun Jan 24 17:41:41 2021 +0800

    [FLINK-21067][runtime][checkpoint] Modify the logic of computing which tasks to trigger/ack/commit to support finished tasks
    
    To support checkpoints after some tasks have finished, for
    each checkpoint we want to trigger the new "root" tasks and
    wait for / commit to all the running tasks. This PR therefore
    modifies the logic that identifies the tasks to trigger, wait
    for, and commit to.
    
    This closes #14740
---
 .../runtime/checkpoint/CheckpointCoordinator.java  |  117 +-
 .../flink/runtime/checkpoint/CheckpointPlan.java   |   31 +-
 .../checkpoint/CheckpointPlanCalculator.java       |  109 +-
 .../CheckpointPlanCalculatorContext.java           |   42 +
 .../DefaultCheckpointPlanCalculator.java           |  342 ++++++
 .../ExecutionAttemptMappingProvider.java           |    8 +-
 .../runtime/checkpoint/PendingCheckpoint.java      |    8 +-
 .../runtime/checkpoint/SubtaskStateStats.java      |    4 +
 .../flink/runtime/executiongraph/Execution.java    |    2 +-
 .../runtime/executiongraph/ExecutionGraph.java     |   40 +-
 ...cutionGraphCheckpointPlanCalculatorContext.java |   48 +
 .../CheckpointCoordinatorFailureTest.java          |   22 +-
 .../CheckpointCoordinatorMasterHooksTest.java      |   96 +-
 .../CheckpointCoordinatorRestoringTest.java        |  265 ++--
 .../checkpoint/CheckpointCoordinatorTest.java      | 1287 +++++++++++---------
 .../CheckpointCoordinatorTestingUtils.java         |  524 ++++----
 .../CheckpointCoordinatorTriggeringTest.java       |  256 ++--
 .../checkpoint/CheckpointStateRestoreTest.java     |  110 +-
 .../checkpoint/CheckpointStatsTrackerTest.java     |   67 +-
 .../DefaultCheckpointPlanCalculatorTest.java       |  429 +++++++
 .../FailoverStrategyCheckpointCoordinatorTest.java |   53 +-
 .../runtime/checkpoint/PendingCheckpointTest.java  |   19 +-
 22 files changed, 2401 insertions(+), 1478 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index cf9dc47..38ddee5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.checkpoint.CheckpointType.PostCheckpointAction;
 import org.apache.flink.runtime.checkpoint.hooks.MasterHooks;
 import org.apache.flink.runtime.concurrent.FutureUtils;
@@ -78,6 +79,7 @@ import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Predicate;
+import java.util.stream.Stream;
 
 import static java.util.stream.Collectors.toMap;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
@@ -521,25 +523,41 @@ public class CheckpointCoordinator {
                 preCheckGlobalState(request.isPeriodic);
             }
 
-            CheckpointPlan checkpointPlan = checkpointPlanCalculator.calculateCheckpointPlan();
-
             // we will actually trigger this checkpoint!
             Preconditions.checkState(!isTriggering);
             isTriggering = true;
 
             final long timestamp = System.currentTimeMillis();
+
+            CompletableFuture<CheckpointPlan> checkpointPlanFuture =
+                    checkpointPlanCalculator.calculateCheckpointPlan();
+
             final CompletableFuture<PendingCheckpoint> pendingCheckpointCompletableFuture =
-                    initializeCheckpoint(request.props, request.externalSavepointLocation)
+                    checkpointPlanFuture
                             .thenApplyAsync(
-                                    (checkpointIdAndStorageLocation) ->
+                                    plan -> {
+                                        try {
+                                            CheckpointIdAndStorageLocation
+                                                    checkpointIdAndStorageLocation =
+                                                            initializeCheckpoint(
+                                                                    request.props,
+                                                                    request.externalSavepointLocation);
+                                            return new Tuple2<>(
+                                                    plan, checkpointIdAndStorageLocation);
+                                        } catch (Throwable e) {
+                                            throw new CompletionException(e);
+                                        }
+                                    },
+                                    executor)
+                            .thenApplyAsync(
+                                    (checkpointInfo) ->
                                             createPendingCheckpoint(
                                                     timestamp,
                                                     request.props,
-                                                    checkpointPlan,
+                                                    checkpointInfo.f0,
                                                     request.isPeriodic,
-                                                    checkpointIdAndStorageLocation.checkpointId,
-                                                    checkpointIdAndStorageLocation
-                                                            .checkpointStorageLocation,
+                                                    checkpointInfo.f1.checkpointId,
+                                                    checkpointInfo.f1.checkpointStorageLocation,
                                                     request.getOnCompletionFuture()),
                                     timer);
 
@@ -608,7 +626,9 @@ public class CheckpointCoordinator {
                                                         checkpointId,
                                                         checkpoint.getCheckpointStorageLocation(),
                                                         request.props,
-                                                        checkpointPlan.getTasksToTrigger());
+                                                        checkpoint
+                                                                .getCheckpointPlan()
+                                                                .getTasksToTrigger());
 
                                                 coordinatorsToCheckpoint.forEach(
                                                         (ctx) ->
@@ -645,38 +665,29 @@ public class CheckpointCoordinator {
     }
 
     /**
-     * Initialize the checkpoint trigger asynchronously. It will be executed in io thread due to it
-     * might be time-consuming.
+     * Initialize the checkpoint trigger. It is expected to be executed in the io thread
+     * because it might be time-consuming.
      *
      * @param props checkpoint properties
      * @param externalSavepointLocation the external savepoint location, it might be null
-     * @return the future of initialized result, checkpoint id and checkpoint location
+     * @return the initialized result, checkpoint id and checkpoint location
      */
-    private CompletableFuture<CheckpointIdAndStorageLocation> initializeCheckpoint(
-            CheckpointProperties props, @Nullable String externalSavepointLocation) {
+    private CheckpointIdAndStorageLocation initializeCheckpoint(
+            CheckpointProperties props, @Nullable String externalSavepointLocation)
+            throws Exception {
 
-        return CompletableFuture.supplyAsync(
-                () -> {
-                    try {
-                        // this must happen outside the coordinator-wide lock, because it
-                        // communicates
-                        // with external services (in HA mode) and may block for a while.
-                        long checkpointID = checkpointIdCounter.getAndIncrement();
-
-                        CheckpointStorageLocation checkpointStorageLocation =
-                                props.isSavepoint()
-                                        ? checkpointStorageView.initializeLocationForSavepoint(
-                                                checkpointID, externalSavepointLocation)
-                                        : checkpointStorageView.initializeLocationForCheckpoint(
-                                                checkpointID);
-
-                        return new CheckpointIdAndStorageLocation(
-                                checkpointID, checkpointStorageLocation);
-                    } catch (Throwable throwable) {
-                        throw new CompletionException(throwable);
-                    }
-                },
-                executor);
+        // this must happen outside the coordinator-wide lock, because it
+        // communicates
+        // with external services (in HA mode) and may block for a while.
+        long checkpointID = checkpointIdCounter.getAndIncrement();
+
+        CheckpointStorageLocation checkpointStorageLocation =
+                props.isSavepoint()
+                        ? checkpointStorageView.initializeLocationForSavepoint(
+                                checkpointID, externalSavepointLocation)
+                        : checkpointStorageView.initializeLocationForCheckpoint(checkpointID);
+
+        return new CheckpointIdAndStorageLocation(checkpointID, checkpointStorageLocation);
     }
 
     private PendingCheckpoint createPendingCheckpoint(
@@ -794,8 +805,6 @@ public class CheckpointCoordinator {
      * @param checkpointStorageLocation the checkpoint location
      * @param props the checkpoint properties
      * @param tasksToTrigger the executions which should be triggered
-     * @param advanceToEndOfTime Flag indicating if the source should inject a {@code MAX_WATERMARK}
-     *     in the pipeline to fire any registered event-time timers.
      */
     private void snapshotTaskState(
             long timestamp,
@@ -874,6 +883,8 @@ public class CheckpointCoordinator {
                 synchronized (lock) {
                     abortPendingCheckpoint(checkpoint, cause);
                 }
+            } else {
+                LOG.warn("Failed to trigger checkpoint for job {}.)", job, throwable);
             }
         } finally {
             isTriggering = false;
@@ -2048,18 +2059,36 @@ public class CheckpointCoordinator {
             return;
         }
         Map<JobVertexID, Integer> vertices =
-                checkpoint.getCheckpointPlan().getTasksToWaitFor().values().stream()
+                Stream.concat(
+                                checkpoint.getCheckpointPlan().getTasksToWaitFor().stream(),
+                                checkpoint.getCheckpointPlan().getFinishedTasks().stream())
+                        .map(Execution::getVertex)
                         .map(ExecutionVertex::getJobVertex)
                         .distinct()
                         .collect(
                                 toMap(
                                         ExecutionJobVertex::getJobVertexId,
                                         ExecutionJobVertex::getParallelism));
-        statsTracker.reportPendingCheckpoint(
-                checkpoint.getCheckpointID(),
-                checkpoint.getCheckpointTimestamp(),
-                checkpoint.getProps(),
-                vertices);
+
+        PendingCheckpointStats pendingCheckpointStats =
+                statsTracker.reportPendingCheckpoint(
+                        checkpoint.getCheckpointID(),
+                        checkpoint.getCheckpointTimestamp(),
+                        checkpoint.getProps(),
+                        vertices);
+
+        reportFinishedTasks(
+                pendingCheckpointStats, checkpoint.getCheckpointPlan().getFinishedTasks());
+    }
+
+    private void reportFinishedTasks(
+            PendingCheckpointStats pendingCheckpointStats, List<Execution> finishedTasks) {
+        long now = System.currentTimeMillis();
+        finishedTasks.forEach(
+                execution ->
+                        pendingCheckpointStats.reportSubtaskStats(
+                                execution.getVertex().getJobvertexId(),
+                                new SubtaskStateStats(execution.getParallelSubtaskIndex(), now)));
     }
 
     @Nullable
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlan.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlan.java
index 4a0f147..8dd9e6e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlan.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlan.java
@@ -19,16 +19,15 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.runtime.executiongraph.Execution;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 
 import java.util.List;
-import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
- * The brief of one checkpoint, indicating which tasks to trigger, waiting for acknowledge or commit
+ * The plan of one checkpoint, indicating which tasks to trigger, waiting for acknowledge or commit
  * for one specific checkpoint.
  */
 class CheckpointPlan {
@@ -37,7 +36,7 @@ class CheckpointPlan {
     private final List<Execution> tasksToTrigger;
 
     /** Tasks who need to acknowledge a checkpoint before it succeeds. */
-    private final Map<ExecutionAttemptID, ExecutionVertex> tasksToWaitFor;
+    private final List<Execution> tasksToWaitFor;
 
     /**
      * Tasks that are still running when taking the checkpoint, these need to be sent a message when
@@ -45,25 +44,43 @@ class CheckpointPlan {
      */
     private final List<ExecutionVertex> tasksToCommitTo;
 
+    /** Tasks that have already been finished when taking the checkpoint. */
+    private final List<Execution> finishedTasks;
+
+    /** The job vertices whose tasks are all finished when taking the checkpoint. */
+    private final List<ExecutionJobVertex> fullyFinishedJobVertex;
+
     CheckpointPlan(
             List<Execution> tasksToTrigger,
-            Map<ExecutionAttemptID, ExecutionVertex> tasksToWaitFor,
-            List<ExecutionVertex> tasksToCommitTo) {
+            List<Execution> tasksToWaitFor,
+            List<ExecutionVertex> tasksToCommitTo,
+            List<Execution> finishedTasks,
+            List<ExecutionJobVertex> fullyFinishedJobVertex) {
 
         this.tasksToTrigger = checkNotNull(tasksToTrigger);
         this.tasksToWaitFor = checkNotNull(tasksToWaitFor);
         this.tasksToCommitTo = checkNotNull(tasksToCommitTo);
+        this.finishedTasks = checkNotNull(finishedTasks);
+        this.fullyFinishedJobVertex = checkNotNull(fullyFinishedJobVertex);
     }
 
     List<Execution> getTasksToTrigger() {
         return tasksToTrigger;
     }
 
-    Map<ExecutionAttemptID, ExecutionVertex> getTasksToWaitFor() {
+    List<Execution> getTasksToWaitFor() {
         return tasksToWaitFor;
     }
 
     List<ExecutionVertex> getTasksToCommitTo() {
         return tasksToCommitTo;
     }
+
+    public List<Execution> getFinishedTasks() {
+        return finishedTasks;
+    }
+
+    public List<ExecutionJobVertex> getFullyFinishedJobVertex() {
+        return fullyFinishedJobVertex;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculator.java
index 806afc5..b49581f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculator.java
@@ -18,109 +18,18 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.executiongraph.Execution;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import java.util.concurrent.CompletableFuture;
 
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-
-/** Computes the tasks to trigger, wait or commit for each checkpoint. */
-public class CheckpointPlanCalculator {
-    private static final Logger LOG = LoggerFactory.getLogger(CheckpointPlanCalculator.class);
-
-    private final JobID jobId;
-
-    private final List<ExecutionVertex> tasksToTrigger;
-
-    private final List<ExecutionVertex> tasksToWait;
-
-    private final List<ExecutionVertex> tasksToCommitTo;
-
-    public CheckpointPlanCalculator(
-            JobID jobId,
-            List<ExecutionVertex> tasksToTrigger,
-            List<ExecutionVertex> tasksToWait,
-            List<ExecutionVertex> tasksToCommitTo) {
-
-        this.jobId = jobId;
-        this.tasksToTrigger = Collections.unmodifiableList(tasksToTrigger);
-        this.tasksToWait = Collections.unmodifiableList(tasksToWait);
-        this.tasksToCommitTo = Collections.unmodifiableList(tasksToCommitTo);
-    }
-
-    public CheckpointPlan calculateCheckpointPlan() throws CheckpointException {
-        return new CheckpointPlan(
-                Collections.unmodifiableList(getTriggerExecutions()),
-                Collections.unmodifiableMap(getAckTasks()),
-                tasksToCommitTo);
-    }
-
-    /**
-     * Check if all tasks that we need to trigger are running. If not, abort the checkpoint.
-     *
-     * @return the executions need to be triggered.
-     * @throws CheckpointException the exception fails checking
-     */
-    private List<Execution> getTriggerExecutions() throws CheckpointException {
-        List<Execution> executionsToTrigger = new ArrayList<>(tasksToTrigger.size());
-        for (ExecutionVertex executionVertex : tasksToTrigger) {
-            Execution ee = executionVertex.getCurrentExecutionAttempt();
-            if (ee == null) {
-                LOG.info(
-                        "Checkpoint triggering task {} of job {} is not being executed at the moment. Aborting checkpoint.",
-                        executionVertex.getTaskNameWithSubtaskIndex(),
-                        executionVertex.getJobId());
-                throw new CheckpointException(
-                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
-            } else if (ee.getState() == ExecutionState.RUNNING) {
-                executionsToTrigger.add(ee);
-            } else {
-                LOG.info(
-                        "Checkpoint triggering task {} of job {} is not in state {} but {} instead. Aborting checkpoint.",
-                        executionVertex.getTaskNameWithSubtaskIndex(),
-                        jobId,
-                        ExecutionState.RUNNING,
-                        ee.getState());
-                throw new CheckpointException(
-                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
-            }
-        }
-
-        return executionsToTrigger;
-    }
+/**
+ * Calculates the plan of the next checkpoint, including the tasks to trigger, wait or commit for
+ * each checkpoint.
+ */
+public interface CheckpointPlanCalculator {
 
     /**
-     * Check if all tasks that need to acknowledge the checkpoint are running. If not, abort the
-     * checkpoint
+     * Calculates the plan of the next checkpoint.
      *
-     * @return the execution vertices which should give an ack response
-     * @throws CheckpointException the exception fails checking
+     * @return The result plan.
      */
-    private Map<ExecutionAttemptID, ExecutionVertex> getAckTasks() throws CheckpointException {
-        Map<ExecutionAttemptID, ExecutionVertex> ackTasks = new HashMap<>(tasksToWait.size());
-
-        for (ExecutionVertex ev : tasksToWait) {
-            Execution ee = ev.getCurrentExecutionAttempt();
-            if (ee != null) {
-                ackTasks.put(ee.getAttemptId(), ev);
-            } else {
-                LOG.info(
-                        "Checkpoint acknowledging task {} of job {} is not being executed at the moment. Aborting checkpoint.",
-                        ev.getTaskNameWithSubtaskIndex(),
-                        jobId);
-                throw new CheckpointException(
-                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
-            }
-        }
-        return ackTasks;
-    }
+    CompletableFuture<CheckpointPlan> calculateCheckpointPlan();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculatorContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculatorContext.java
new file mode 100644
index 0000000..e6238c4
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointPlanCalculatorContext.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
+
+/**
+ * Supplies {@link DefaultCheckpointPlanCalculator} with the job-level information it needs when
+ * computing the plan of a checkpoint.
+ */
+public interface CheckpointPlanCalculatorContext {
+
+    /**
+     * Tells whether at least one task of the job has already finished.
+     *
+     * @return {@code true} if some tasks are finished, {@code false} otherwise.
+     */
+    boolean hasFinishedTasks();
+
+    /**
+     * Returns the executor that runs on the job's main thread.
+     *
+     * @return The job's main thread executor.
+     */
+    ScheduledExecutor getMainExecutor();
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculator.java
new file mode 100644
index 0000000..75d9e0e
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculator.java
@@ -0,0 +1,342 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionEdge;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.JobEdge;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.BitSet;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.stream.Collectors;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * Default implementation for {@link CheckpointPlanCalculator}. If all tasks are running, it
+ * directly marks all the sources as tasks to trigger. Otherwise it looks for the running tasks
+ * that have no running predecessors and marks those as tasks to trigger.
+ */
+public class DefaultCheckpointPlanCalculator implements CheckpointPlanCalculator {
+
+    private final JobID jobId;
+
+    private final CheckpointPlanCalculatorContext context;
+
+    private final List<ExecutionJobVertex> jobVerticesInTopologyOrder = new ArrayList<>();
+
+    private final List<ExecutionVertex> allTasks = new ArrayList<>();
+
+    private final List<ExecutionVertex> sourceTasks = new ArrayList<>();
+
+    /**
+     * Number of subtasks of each job vertex. {@link BitSet#size()} reports the number of bits of
+     * space allocated by a bit set (a multiple of 64 in the JDK implementation), not the number of
+     * subtasks. The subtask count is therefore tracked separately, so we can decide whether all
+     * subtasks of an upstream vertex are still running.
+     */
+    private final Map<JobVertexID, Integer> numTasksByVertex = new HashMap<>();
+
+    /**
+     * TODO Temporary flag to allow checkpoints after tasks finished. This is disabled for regular
+     * jobs to keep the current behavior but we want to allow it in tests. This should be removed
+     * once all parts of the stack support checkpoints after some tasks finished.
+     */
+    private boolean allowCheckpointsAfterTasksFinished;
+
+    public DefaultCheckpointPlanCalculator(
+            JobID jobId,
+            CheckpointPlanCalculatorContext context,
+            Iterable<ExecutionJobVertex> jobVerticesInTopologyOrderIterable) {
+
+        this.jobId = checkNotNull(jobId);
+        this.context = checkNotNull(context);
+
+        checkNotNull(jobVerticesInTopologyOrderIterable);
+        jobVerticesInTopologyOrderIterable.forEach(
+                jobVertex -> {
+                    jobVerticesInTopologyOrder.add(jobVertex);
+                    allTasks.addAll(Arrays.asList(jobVertex.getTaskVertices()));
+                    numTasksByVertex.put(
+                            jobVertex.getJobVertexId(), jobVertex.getTaskVertices().length);
+
+                    if (jobVertex.getJobVertex().isInputVertex()) {
+                        sourceTasks.addAll(Arrays.asList(jobVertex.getTaskVertices()));
+                    }
+                });
+    }
+
+    public void setAllowCheckpointsAfterTasksFinished(boolean allowCheckpointsAfterTasksFinished) {
+        this.allowCheckpointsAfterTasksFinished = allowCheckpointsAfterTasksFinished;
+    }
+
+    /**
+     * {@inheritDoc}
+     *
+     * <p>The plan is computed on the main thread executor provided by the context. The returned
+     * future completes exceptionally, wrapping the original failure in a {@link
+     * CompletionException}, in these cases: some tasks have finished while checkpoints after
+     * finished tasks are not allowed, some tasks have no current execution attempt, or a task to
+     * trigger has not started yet.
+     */
+    @Override
+    public CompletableFuture<CheckpointPlan> calculateCheckpointPlan() {
+        return CompletableFuture.supplyAsync(
+                () -> {
+                    try {
+                        if (context.hasFinishedTasks() && !allowCheckpointsAfterTasksFinished) {
+                            throw new CheckpointException(
+                                    String.format(
+                                            "some tasks of job %s has been finished, abort the checkpoint",
+                                            jobId),
+                                    CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
+                        }
+
+                        checkAllTasksInitiated();
+
+                        CheckpointPlan result =
+                                context.hasFinishedTasks()
+                                        ? calculateAfterTasksFinished()
+                                        : calculateWithAllTasksRunning();
+
+                        checkTasksStarted(result.getTasksToTrigger());
+
+                        return result;
+                    } catch (Throwable throwable) {
+                        throw new CompletionException(throwable);
+                    }
+                },
+                context.getMainExecutor());
+    }
+
+    /**
+     * Checks if all tasks are attached with the current Execution already. This method should be
+     * called from JobMaster main thread executor.
+     *
+     * @throws CheckpointException if some tasks do not have attached Execution.
+     */
+    private void checkAllTasksInitiated() throws CheckpointException {
+        for (ExecutionVertex task : allTasks) {
+            if (task.getCurrentExecutionAttempt() == null) {
+                throw new CheckpointException(
+                        String.format(
+                                "task %s of job %s is not being executed at the moment. Aborting checkpoint.",
+                                task.getTaskNameWithSubtaskIndex(), jobId),
+                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
+            }
+        }
+    }
+
+    /**
+     * Checks that none of the tasks to trigger is still in the CREATED, SCHEDULED or DEPLOYING
+     * state. This method should be called from JobMaster main thread executor.
+     *
+     * @param toTrigger The executions that the plan wants to trigger.
+     * @throws CheckpointException if some task to trigger is still CREATED, SCHEDULED or
+     *     DEPLOYING.
+     */
+    private void checkTasksStarted(List<Execution> toTrigger) throws CheckpointException {
+        for (Execution execution : toTrigger) {
+            if (execution.getState() == ExecutionState.CREATED
+                    || execution.getState() == ExecutionState.SCHEDULED
+                    || execution.getState() == ExecutionState.DEPLOYING) {
+
+                throw new CheckpointException(
+                        String.format(
+                                "Checkpoint triggering task %s of job %s has not being executed at the moment. "
+                                        + "Aborting checkpoint.",
+                                execution.getVertex().getTaskNameWithSubtaskIndex(), jobId),
+                        CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING);
+            }
+        }
+    }
+
+    /**
+     * Computes the checkpoint plan when all tasks are running. It simply marks all the source tasks
+     * as tasks to trigger, and all the tasks as tasks to wait for and commit to.
+     *
+     * @return The plan of this checkpoint.
+     */
+    private CheckpointPlan calculateWithAllTasksRunning() {
+        List<Execution> executionsToTrigger =
+                sourceTasks.stream()
+                        .map(ExecutionVertex::getCurrentExecutionAttempt)
+                        .collect(Collectors.toList());
+
+        List<Execution> tasksToWaitFor = createTaskToWaitFor(allTasks);
+
+        return new CheckpointPlan(
+                Collections.unmodifiableList(executionsToTrigger),
+                Collections.unmodifiableList(tasksToWaitFor),
+                Collections.unmodifiableList(allTasks),
+                Collections.emptyList(),
+                Collections.emptyList());
+    }
+
+    /**
+     * Calculates the checkpoint plan after some tasks have finished. We iterate the job graph in
+     * topological order to find the tasks that are still running but do not have any running
+     * predecessors; those become the tasks to trigger.
+     *
+     * @return The plan of this checkpoint.
+     */
+    private CheckpointPlan calculateAfterTasksFinished() {
+        // First collect the task running status into BitSet so that we could
+        // do JobVertex level judgement for some vertices and avoid time-consuming
+        // access to volatile isFinished flag of Execution.
+        Map<JobVertexID, BitSet> taskRunningStatusByVertex = collectTaskRunningStatus();
+
+        List<Execution> tasksToTrigger = new ArrayList<>();
+        List<Execution> tasksToWaitFor = new ArrayList<>();
+        List<ExecutionVertex> tasksToCommitTo = new ArrayList<>();
+        List<Execution> finishedTasks = new ArrayList<>();
+        List<ExecutionJobVertex> fullyFinishedJobVertex = new ArrayList<>();
+
+        for (ExecutionJobVertex jobVertex : jobVerticesInTopologyOrder) {
+            BitSet taskRunningStatus = taskRunningStatusByVertex.get(jobVertex.getJobVertexId());
+
+            if (taskRunningStatus.cardinality() == 0) {
+                fullyFinishedJobVertex.add(jobVertex);
+
+                for (ExecutionVertex task : jobVertex.getTaskVertices()) {
+                    finishedTasks.add(task.getCurrentExecutionAttempt());
+                }
+
+                continue;
+            }
+
+            List<JobEdge> prevJobEdges = jobVertex.getJobVertex().getInputs();
+
+            // this is an optimization: we determine at the JobVertex level if some tasks can even
+            // be eligible for being in the "triggerTo" set.
+            boolean someTasksMustBeTriggered =
+                    someTasksMustBeTriggered(taskRunningStatusByVertex, prevJobEdges);
+
+            for (int i = 0; i < jobVertex.getTaskVertices().length; ++i) {
+                ExecutionVertex task = jobVertex.getTaskVertices()[i];
+                if (taskRunningStatus.get(task.getParallelSubtaskIndex())) {
+                    tasksToWaitFor.add(task.getCurrentExecutionAttempt());
+                    tasksToCommitTo.add(task);
+
+                    if (someTasksMustBeTriggered) {
+                        boolean hasRunningPrecedentTasks =
+                                hasRunningPrecedentTasks(
+                                        task, prevJobEdges, taskRunningStatusByVertex);
+
+                        if (!hasRunningPrecedentTasks) {
+                            tasksToTrigger.add(task.getCurrentExecutionAttempt());
+                        }
+                    }
+                } else {
+                    finishedTasks.add(task.getCurrentExecutionAttempt());
+                }
+            }
+        }
+
+        return new CheckpointPlan(
+                Collections.unmodifiableList(tasksToTrigger),
+                Collections.unmodifiableList(tasksToWaitFor),
+                Collections.unmodifiableList(tasksToCommitTo),
+                Collections.unmodifiableList(finishedTasks),
+                Collections.unmodifiableList(fullyFinishedJobVertex));
+    }
+
+    /**
+     * Decides at the job vertex level whether any task of the vertex could need triggering.
+     *
+     * @param runningTasksByVertex The running status of the tasks of each job vertex.
+     * @param prevJobEdges The input edges of the current job vertex.
+     * @return {@code false} if some input edge guarantees that every task of the current vertex
+     *     has a running predecessor, {@code true} otherwise.
+     */
+    private boolean someTasksMustBeTriggered(
+            Map<JobVertexID, BitSet> runningTasksByVertex, List<JobEdge> prevJobEdges) {
+
+        for (JobEdge jobEdge : prevJobEdges) {
+            DistributionPattern distributionPattern = jobEdge.getDistributionPattern();
+            JobVertexID upstreamVertexId = jobEdge.getSource().getProducer().getID();
+            BitSet upstreamRunningStatus = runningTasksByVertex.get(upstreamVertexId);
+            int numUpstreamTasks = numTasksByVertex.get(upstreamVertexId);
+
+            if (hasActiveUpstreamVertex(
+                    distributionPattern, upstreamRunningStatus, numUpstreamTasks)) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    /**
+     * Every task must have active upstream tasks if
+     *
+     * <ol>
+     *   <li>ALL_TO_ALL connection and some predecessors are still running.
+     *   <li>POINTWISE connection and all predecessors are still running.
+     * </ol>
+     *
+     * @param distribution The distribution pattern between the upstream vertex and the current
+     *     vertex.
+     * @param upstreamRunningTasks The running tasks of the upstream vertex.
+     * @param numUpstreamTasks The number of subtasks of the upstream vertex. It is passed
+     *     explicitly because {@link BitSet#size()} is the allocated capacity, not this count.
+     * @return Whether every task of the current vertex is connected to some active predecessors.
+     */
+    private boolean hasActiveUpstreamVertex(
+            DistributionPattern distribution, BitSet upstreamRunningTasks, int numUpstreamTasks) {
+        return (distribution == DistributionPattern.ALL_TO_ALL
+                        && upstreamRunningTasks.cardinality() > 0)
+                || (distribution == DistributionPattern.POINTWISE
+                        && upstreamRunningTasks.cardinality() == numUpstreamTasks);
+    }
+
+    /**
+     * Checks whether the given task has a running predecessor over any of its POINTWISE inputs.
+     * ALL_TO_ALL inputs are not inspected here. This method is only reached when {@link
+     * #someTasksMustBeTriggered} returned {@code true}, which means no ALL_TO_ALL input has any
+     * running task.
+     *
+     * @param vertex The task to check.
+     * @param prevJobEdges The input edges of the task's job vertex, indexed like its input edges.
+     * @param taskRunningStatusByVertex The running status of the tasks of each job vertex.
+     * @return Whether some POINTWISE predecessor of the task is still running.
+     */
+    private boolean hasRunningPrecedentTasks(
+            ExecutionVertex vertex,
+            List<JobEdge> prevJobEdges,
+            Map<JobVertexID, BitSet> taskRunningStatusByVertex) {
+        for (int i = 0; i < prevJobEdges.size(); ++i) {
+            if (prevJobEdges.get(i).getDistributionPattern() == DistributionPattern.POINTWISE) {
+                for (ExecutionEdge executionEdge : vertex.getInputEdges(i)) {
+                    ExecutionVertex precedentTask = executionEdge.getSource().getProducer();
+                    BitSet precedentVertexRunningStatus =
+                            taskRunningStatusByVertex.get(precedentTask.getJobvertexId());
+
+                    if (precedentVertexRunningStatus.get(precedentTask.getParallelSubtaskIndex())) {
+                        return true;
+                    }
+                }
+            }
+        }
+
+        return false;
+    }
+
+    /**
+     * Collects the task running status for each job vertex. Bit {@code i} is set when subtask
+     * {@code i} of the vertex has not finished.
+     *
+     * @return The task running status for each job vertex.
+     */
+    @VisibleForTesting
+    Map<JobVertexID, BitSet> collectTaskRunningStatus() {
+        Map<JobVertexID, BitSet> runningStatusByVertex = new HashMap<>();
+
+        for (ExecutionJobVertex vertex : jobVerticesInTopologyOrder) {
+            BitSet runningTasks = new BitSet(vertex.getTaskVertices().length);
+
+            for (int i = 0; i < vertex.getTaskVertices().length; ++i) {
+                if (!vertex.getTaskVertices()[i].getCurrentExecutionAttempt().isFinished()) {
+                    runningTasks.set(i);
+                }
+            }
+
+            runningStatusByVertex.put(vertex.getJobVertexId(), runningTasks);
+        }
+
+        return runningStatusByVertex;
+    }
+
+    /**
+     * Maps the given tasks to their current execution attempts.
+     *
+     * @param tasks The tasks to wait for.
+     * @return The current execution attempt of each task, in the same order.
+     */
+    private List<Execution> createTaskToWaitFor(List<ExecutionVertex> tasks) {
+        List<Execution> tasksToAck = new ArrayList<>(tasks.size());
+        for (ExecutionVertex task : tasks) {
+            tasksToAck.add(task.getCurrentExecutionAttempt());
+        }
+
+        return tasksToAck;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ExecutionAttemptMappingProvider.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ExecutionAttemptMappingProvider.java
index 14ab9a5..a7f7f83 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ExecutionAttemptMappingProvider.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/ExecutionAttemptMappingProvider.java
@@ -21,14 +21,13 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
-
 /**
  * Provides a mapping from {@link ExecutionAttemptID} to {@link ExecutionVertex} for currently
  * running execution attempts.
@@ -41,8 +40,9 @@ public class ExecutionAttemptMappingProvider {
     /** The cached mapping, which would only be updated on miss. */
     private final LinkedHashMap<ExecutionAttemptID, ExecutionVertex> cachedTasksById;
 
-    public ExecutionAttemptMappingProvider(List<ExecutionVertex> tasks) {
-        this.tasks = checkNotNull(tasks);
+    public ExecutionAttemptMappingProvider(Iterable<ExecutionVertex> tasksIterable) {
+        this.tasks = new ArrayList<>();
+        tasksIterable.forEach(this.tasks::add);
 
         this.cachedTasksById =
                 new LinkedHashMap<ExecutionAttemptID, ExecutionVertex>(tasks.size()) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
index e44d2fa..196c363 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/PendingCheckpoint.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.metadata.CheckpointMetadata;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -141,7 +142,12 @@ public class PendingCheckpoint implements Checkpoint {
         this.checkpointId = checkpointId;
         this.checkpointTimestamp = checkpointTimestamp;
         this.checkpointPlan = checkNotNull(checkpointPlan);
-        this.notYetAcknowledgedTasks = new HashMap<>(checkpointPlan.getTasksToWaitFor());
+
+        this.notYetAcknowledgedTasks = new HashMap<>(checkpointPlan.getTasksToWaitFor().size());
+        for (Execution execution : checkpointPlan.getTasksToWaitFor()) {
+            notYetAcknowledgedTasks.put(execution.getAttemptId(), execution.getVertex());
+        }
+
         this.props = checkNotNull(props);
         this.targetLocation = checkNotNull(targetLocation);
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskStateStats.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskStateStats.java
index bd7719b..0347d2f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskStateStats.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/SubtaskStateStats.java
@@ -64,6 +64,10 @@ public class SubtaskStateStats implements Serializable {
     /** Is the checkpoint completed by this subtask. */
     private final boolean completed;
 
+    /**
+     * Creates stats for a subtask that only records its index and acknowledgement timestamp. All
+     * remaining numeric arguments passed to the delegated constructor are {@code 0}.
+     *
+     * @param subtaskIndex Index of the subtask.
+     * @param ackTimestamp Timestamp of the acknowledgement.
+     */
+    SubtaskStateStats(int subtaskIndex, long ackTimestamp) {
+        // NOTE(review): the trailing boolean flags (false, true) are positional; the last one
+        // presumably maps to the `completed` field -- confirm against the full constructor.
+        this(subtaskIndex, ackTimestamp, 0, 0, 0, 0, 0, 0, 0, false, true);
+    }
+
     SubtaskStateStats(
             int subtaskIndex,
             long ackTimestamp,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 4db17fd..d46fb7f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -896,7 +896,7 @@ public class Execution
     }
 
     @VisibleForTesting
-    void markFinished() {
+    public void markFinished() {
         markFinished(null, null);
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 9e0f6d5..72910ff 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -26,7 +26,6 @@ import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.api.common.accumulators.Accumulator;
 import org.apache.flink.api.common.accumulators.AccumulatorHelper;
 import org.apache.flink.api.common.time.Time;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.SimpleCounter;
@@ -43,6 +42,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointStatsSnapshot;
 import org.apache.flink.runtime.checkpoint.CheckpointStatsTracker;
 import org.apache.flink.runtime.checkpoint.CheckpointsCleaner;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
+import org.apache.flink.runtime.checkpoint.DefaultCheckpointPlanCalculator;
 import org.apache.flink.runtime.checkpoint.ExecutionAttemptMappingProvider;
 import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook;
 import org.apache.flink.runtime.checkpoint.OperatorCoordinatorCheckpointContext;
@@ -223,7 +223,7 @@ public class ExecutionGraph implements AccessExecutionGraph {
     // ------ Execution status and progress. These values are volatile, and accessed under the lock
     // -------
 
-    private int verticesFinished;
+    private int numFinishedVertices;
 
     /** Current status of the job execution. */
     private volatile JobStatus state = JobStatus.CREATED;
@@ -443,9 +443,6 @@ public class ExecutionGraph implements AccessExecutionGraph {
                         new DispatcherThreadFactory(
                                 Thread.currentThread().getThreadGroup(), "Checkpoint Timer"));
 
-        Tuple2<List<ExecutionVertex>, List<ExecutionVertex>> sourceAndAllVertices =
-                getSourceAndAllVertices();
-
         // create the coordinator that triggers and commits checkpoints and holds the state
         checkpointCoordinator =
                 new CheckpointCoordinator(
@@ -460,12 +457,8 @@ public class ExecutionGraph implements AccessExecutionGraph {
                         new ScheduledExecutorServiceAdapter(checkpointCoordinatorTimer),
                         SharedStateRegistry.DEFAULT_FACTORY,
                         failureManager,
-                        new CheckpointPlanCalculator(
-                                getJobID(),
-                                sourceAndAllVertices.f0,
-                                sourceAndAllVertices.f1,
-                                sourceAndAllVertices.f1),
-                        new ExecutionAttemptMappingProvider(sourceAndAllVertices.f1));
+                        createCheckpointPlanCalculator(),
+                        new ExecutionAttemptMappingProvider(getAllExecutionVertices()));
 
         // register the master hooks on the checkpoint coordinator
         for (MasterTriggerRestoreHook<?> hook : masterHooks) {
@@ -491,18 +484,11 @@ public class ExecutionGraph implements AccessExecutionGraph {
         this.checkpointStorageName = checkpointStorage.getClass().getSimpleName();
     }
 
-    private Tuple2<List<ExecutionVertex>, List<ExecutionVertex>> getSourceAndAllVertices() {
-        List<ExecutionVertex> sourceVertices = new ArrayList<>();
-        List<ExecutionVertex> allVertices = new ArrayList<>();
-        for (ExecutionVertex executionVertex : getAllExecutionVertices()) {
-            if (executionVertex.getJobVertex().getJobVertex().isInputVertex()) {
-                sourceVertices.add(executionVertex);
-            }
-
-            allVertices.add(executionVertex);
-        }
-
-        return new Tuple2<>(sourceVertices, allVertices);
+    private CheckpointPlanCalculator createCheckpointPlanCalculator() {
+        return new DefaultCheckpointPlanCalculator(
+                getJobID(),
+                new ExecutionGraphCheckpointPlanCalculatorContext(this),
+                getVerticesTopologically());
     }
 
     @Nullable
@@ -605,6 +591,10 @@ public class ExecutionGraph implements AccessExecutionGraph {
         return numberOfRestartsCounter.getCount();
     }
 
+    /**
+     * Returns the current value of the finished-vertex counter. The counter is incremented by
+     * {@code vertexFinished()} and decremented by {@code vertexUnFinished()}.
+     *
+     * @return the number of vertices currently counted as finished.
+     */
+    public int getNumFinishedVertices() {
+        final int finishedCount = this.numFinishedVertices;
+        return finishedCount;
+    }
+
     @Override
     public ExecutionJobVertex getJobVertex(JobVertexID id) {
         return this.tasks.get(id);
@@ -1096,7 +1086,7 @@ public class ExecutionGraph implements AccessExecutionGraph {
      */
     void vertexFinished() {
         assertRunningInJobMasterMainThread();
-        final int numFinished = ++verticesFinished;
+        final int numFinished = ++numFinishedVertices;
         if (numFinished == numVerticesTotal) {
             // done :-)
 
@@ -1127,7 +1117,7 @@ public class ExecutionGraph implements AccessExecutionGraph {
 
     void vertexUnFinished() {
         assertRunningInJobMasterMainThread();
-        verticesFinished--;
+        numFinishedVertices--;
     }
 
     /**
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphCheckpointPlanCalculatorContext.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphCheckpointPlanCalculatorContext.java
new file mode 100644
index 0000000..83679b3
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraphCheckpointPlanCalculatorContext.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.runtime.checkpoint.CheckpointPlanCalculatorContext;
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
+
+import static org.apache.flink.util.Preconditions.checkNotNull;
+
+/**
+ * A {@link CheckpointPlanCalculatorContext} implementation based on the information from an {@link
+ * ExecutionGraph}.
+ */
+public class ExecutionGraphCheckpointPlanCalculatorContext
+        implements CheckpointPlanCalculatorContext {
+
+    /** The graph that every query is delegated to. */
+    private final ExecutionGraph graph;
+
+    public ExecutionGraphCheckpointPlanCalculatorContext(ExecutionGraph executionGraph) {
+        this.graph = checkNotNull(executionGraph);
+    }
+
+    /** Returns the job master main thread executor of the underlying graph. */
+    @Override
+    public ScheduledExecutor getMainExecutor() {
+        return graph.getJobMasterMainThreadExecutor();
+    }
+
+    /** Reports finished tasks as soon as the graph counts at least one finished vertex. */
+    @Override
+    public boolean hasFinishedTasks() {
+        final int finishedVertices = graph.getNumFinishedVertices();
+        return 0 < finishedVertices;
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index a267f8c..5fac8e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -18,14 +18,14 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
@@ -59,20 +59,22 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
      */
     @Test
     public void testFailingCompletedCheckpointStoreAdd() throws Exception {
-        JobID jid = new JobID();
+        JobVertexID jobVertexId = new JobVertexID();
 
         final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
                 new ManuallyTriggeredScheduledExecutor();
 
-        final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID();
-        final ExecutionVertex vertex =
-                CheckpointCoordinatorTestingUtils.mockExecutionVertex(executionAttemptId);
+        ExecutionGraph testGraph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexId)
+                        .build();
+
+        ExecutionVertex vertex = testGraph.getJobVertex(jobVertexId).getTaskVertices()[0];
 
         // set up the coordinator and validate the initial state
         CheckpointCoordinator coord =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jid)
-                        .setTasks(new ExecutionVertex[] {vertex})
+                        .setExecutionGraph(testGraph)
                         .setCompletedCheckpointStore(new FailingCompletedCheckpointStore())
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
@@ -128,8 +130,8 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 
         AcknowledgeCheckpoint acknowledgeMessage =
                 new AcknowledgeCheckpoint(
-                        jid,
-                        executionAttemptId,
+                        testGraph.getJobID(),
+                        vertex.getCurrentExecutionAttempt().getAttemptId(),
                         checkpointId,
                         new CheckpointMetrics(),
                         subtaskState);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index 0d14495..0dd7104 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -24,7 +24,9 @@ import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphCheckpointPlanCalculatorContext;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
@@ -39,7 +41,6 @@ import org.mockito.stubbing.Answer;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
-import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -48,7 +49,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.StringSerializer;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -73,8 +73,12 @@ public class CheckpointCoordinatorMasterHooksTest {
 
     /** This method tests that hooks with the same identifier are not registered multiple times. */
     @Test
-    public void testDeduplicateOnRegister() {
-        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID());
+    public void testDeduplicateOnRegister() throws Exception {
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
+        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(graph);
 
         MasterTriggerRestoreHook<?> hook1 = mock(MasterTriggerRestoreHook.class);
         when(hook1.getIdentifier()).thenReturn("test id");
@@ -92,8 +96,12 @@ public class CheckpointCoordinatorMasterHooksTest {
 
     /** Test that validates correct exceptions when supplying hooks with invalid IDs. */
     @Test
-    public void testNullOrInvalidId() {
-        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(new JobID());
+    public void testNullOrInvalidId() throws Exception {
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
+        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(graph);
 
         try {
             cc.addMasterHook(null);
@@ -128,10 +136,11 @@ public class CheckpointCoordinatorMasterHooksTest {
         when(hook2.getIdentifier()).thenReturn(id2);
 
         // create the checkpoint coordinator
-        final JobID jid = new JobID();
-        final ExecutionAttemptID execId = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex = mockExecutionVertex(execId);
-        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
+        CheckpointCoordinator cc = instantiateCheckpointCoordinator(graph);
 
         cc.addMasterHook(hook1);
         cc.addMasterHook(hook2);
@@ -181,14 +190,15 @@ public class CheckpointCoordinatorMasterHooksTest {
         when(statelessHook.getIdentifier()).thenReturn("some-id");
 
         // create the checkpoint coordinator
-        final JobID jid = new JobID();
-        final ExecutionAttemptID execId = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex = mockExecutionVertex(execId);
+        JobVertexID jobVertexId = new JobVertexID();
+        final ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexId)
+                        .build();
         final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
                 new ManuallyTriggeredScheduledExecutor();
         final CheckpointCoordinator cc =
-                instantiateCheckpointCoordinator(
-                        jid, manuallyTriggeredScheduledExecutor, ackVertex);
+                instantiateCheckpointCoordinator(graph, manuallyTriggeredScheduledExecutor);
 
         cc.addMasterHook(statefulHook1);
         cc.addMasterHook(statelessHook);
@@ -207,10 +217,17 @@ public class CheckpointCoordinatorMasterHooksTest {
         verify(statelessHook, times(1))
                 .triggerCheckpoint(anyLong(), anyLong(), any(Executor.class));
 
+        ExecutionAttemptID attemptID =
+                graph.getJobVertex(jobVertexId)
+                        .getTaskVertices()[0]
+                        .getCurrentExecutionAttempt()
+                        .getAttemptId();
+
         final long checkpointId =
                 cc.getPendingCheckpoints().values().iterator().next().getCheckpointId();
         cc.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jid, execId, checkpointId), "Unknown location");
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID, checkpointId),
+                "Unknown location");
         assertEquals(0, cc.getNumberOfPendingCheckpoints());
 
         assertEquals(1, cc.getNumberOfRetainedSuccessfulCheckpoints());
@@ -280,9 +297,11 @@ public class CheckpointCoordinatorMasterHooksTest {
                         CheckpointProperties.forCheckpoint(
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
                         new TestCompletedCheckpointStorageLocation());
-        final ExecutionAttemptID execId = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex = mockExecutionVertex(execId);
-        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
+        CheckpointCoordinator cc = instantiateCheckpointCoordinator(graph);
 
         cc.addMasterHook(statefulHook1);
         cc.addMasterHook(statelessHook);
@@ -338,9 +357,11 @@ public class CheckpointCoordinatorMasterHooksTest {
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
                         new TestCompletedCheckpointStorageLocation());
 
-        final ExecutionAttemptID execId = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex = mockExecutionVertex(execId);
-        final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
+        CheckpointCoordinator cc = instantiateCheckpointCoordinator(graph);
 
         cc.addMasterHook(statefulHook);
         cc.addMasterHook(statelessHook);
@@ -374,14 +395,14 @@ public class CheckpointCoordinatorMasterHooksTest {
         final String id = "id";
 
         // create the checkpoint coordinator
-        final JobID jid = new JobID();
-        final ExecutionAttemptID execId = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex = mockExecutionVertex(execId);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
         final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
                 new ManuallyTriggeredScheduledExecutor();
-        final CheckpointCoordinator cc =
-                instantiateCheckpointCoordinator(
-                        jid, manuallyTriggeredScheduledExecutor, ackVertex);
+        CheckpointCoordinator cc =
+                instantiateCheckpointCoordinator(graph, manuallyTriggeredScheduledExecutor);
 
         final MasterTriggerRestoreHook<Void> hook = mockGeneric(MasterTriggerRestoreHook.class);
         when(hook.getIdentifier()).thenReturn(id);
@@ -434,15 +455,14 @@ public class CheckpointCoordinatorMasterHooksTest {
     //  utilities
     // ------------------------------------------------------------------------
 
-    private CheckpointCoordinator instantiateCheckpointCoordinator(
-            JobID jid, ExecutionVertex... ackVertices) {
+    private CheckpointCoordinator instantiateCheckpointCoordinator(ExecutionGraph executionGraph) {
 
         return instantiateCheckpointCoordinator(
-                jid, new ManuallyTriggeredScheduledExecutor(), ackVertices);
+                executionGraph, new ManuallyTriggeredScheduledExecutor());
     }
 
     private CheckpointCoordinator instantiateCheckpointCoordinator(
-            JobID jid, ScheduledExecutor testingScheduledExecutor, ExecutionVertex... ackVertices) {
+            ExecutionGraph graph, ScheduledExecutor testingScheduledExecutor) {
 
         CheckpointCoordinatorConfiguration chkConfig =
                 new CheckpointCoordinatorConfiguration(
@@ -457,7 +477,7 @@ public class CheckpointCoordinatorMasterHooksTest {
                         0);
         Executor executor = Executors.directExecutor();
         return new CheckpointCoordinator(
-                jid,
+                graph.getJobID(),
                 chkConfig,
                 Collections.emptyList(),
                 new StandaloneCheckpointIDCounter(),
@@ -468,9 +488,11 @@ public class CheckpointCoordinatorMasterHooksTest {
                 testingScheduledExecutor,
                 SharedStateRegistry.DEFAULT_FACTORY,
                 new CheckpointFailureManager(0, NoOpFailJobCall.INSTANCE),
-                new CheckpointPlanCalculator(
-                        jid, new ArrayList<>(), Arrays.asList(ackVertices), new ArrayList<>()),
-                new ExecutionAttemptMappingProvider(Arrays.asList(ackVertices)));
+                new DefaultCheckpointPlanCalculator(
+                        graph.getJobID(),
+                        new ExecutionGraphCheckpointPlanCalculatorContext(graph),
+                        graph.getVerticesTopologically()),
+                new ExecutionAttemptMappingProvider(graph.getAllExecutionVertices()));
     }
 
     private static <T> T mockGeneric(Class<?> clazz) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
index 8fa53b8..4bc02b6 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
@@ -18,13 +18,13 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.OperatorIDPair;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
@@ -45,15 +45,11 @@ import org.apache.flink.util.TestLogger;
 
 import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
 
-import org.hamcrest.BaseMatcher;
-import org.hamcrest.Description;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
-import org.mockito.Mockito;
-import org.mockito.hamcrest.MockitoHamcrest;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -63,23 +59,22 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.compareKeyedState;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.comparePartitionableState;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateChainedPartitionableStateHandle;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateKeyGroupState;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecution;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
@@ -114,8 +109,6 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
      */
     @Test
     public void testRestoreLatestCheckpointedState() throws Exception {
-        final JobID jid = new JobID();
-
         final JobVertexID jobVertexID1 = new JobVertexID();
         final JobVertexID jobVertexID2 = new JobVertexID();
         int parallelism1 = 3;
@@ -123,26 +116,21 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
         int maxParallelism1 = 42;
         int maxParallelism2 = 13;
 
-        final ExecutionJobVertex jobVertex1 =
-                mockExecutionJobVertex(jobVertexID1, parallelism1, maxParallelism1);
-        final ExecutionJobVertex jobVertex2 =
-                mockExecutionJobVertex(jobVertexID2, parallelism2, maxParallelism2);
-
-        List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
-
-        allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
-        allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+        final ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, parallelism1, maxParallelism1)
+                        .addJobVertex(jobVertexID2, parallelism2, maxParallelism2)
+                        .build();
 
-        ExecutionVertex[] arrayExecutionVertices =
-                allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+        final ExecutionJobVertex jobVertex1 = graph.getJobVertex(jobVertexID1);
+        final ExecutionJobVertex jobVertex2 = graph.getJobVertex(jobVertexID2);
 
         CompletedCheckpointStore store = new EmbeddedCompletedCheckpointStore();
 
         // set up the coordinator and validate the initial state
         CheckpointCoordinator coord =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jid)
-                        .setTasks(arrayExecutionVertices)
+                        .setExecutionGraph(graph)
                         .setCompletedCheckpointStore(store)
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
@@ -165,7 +153,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
             AcknowledgeCheckpoint acknowledgeCheckpoint =
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             jobVertex1
                                     .getTaskVertices()[index]
                                     .getCurrentExecutionAttempt()
@@ -183,7 +171,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
             AcknowledgeCheckpoint acknowledgeCheckpoint =
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             jobVertex2
                                     .getTaskVertices()[index]
                                     .getCurrentExecutionAttempt()
@@ -247,22 +235,25 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
     private void testRestoreLatestCheckpointIsPreferSavepoint(boolean isPreferCheckpoint) {
         try {
-            final JobID jid = new JobID();
             StandaloneCheckpointIDCounter checkpointIDCounter = new StandaloneCheckpointIDCounter();
 
             final JobVertexID statefulId = new JobVertexID();
             final JobVertexID statelessId = new JobVertexID();
 
-            Execution statefulExec1 = mockExecution();
-            Execution statelessExec1 = mockExecution();
+            final ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(statefulId)
+                            .addJobVertex(statelessId)
+                            .build();
+
+            ExecutionJobVertex stateful = graph.getJobVertex(statefulId);
+            ExecutionJobVertex stateless = graph.getJobVertex(statelessId);
 
-            ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 1);
-            ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 1);
+            ExecutionVertex stateful1 = stateful.getTaskVertices()[0];
+            ExecutionVertex stateless1 = stateless.getTaskVertices()[0];
 
-            ExecutionJobVertex stateful =
-                    mockExecutionJobVertex(statefulId, new ExecutionVertex[] {stateful1});
-            ExecutionJobVertex stateless =
-                    mockExecutionJobVertex(statelessId, new ExecutionVertex[] {stateless1});
+            Execution statefulExec1 = stateful1.getCurrentExecutionAttempt();
+            Execution statelessExec1 = stateless1.getCurrentExecutionAttempt();
 
             Set<ExecutionJobVertex> tasks = new HashSet<>();
             tasks.add(stateful);
@@ -276,11 +267,10 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
                             .build();
             CheckpointCoordinator coord =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jid)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(chkConfig)
                             .setCheckpointIDCounter(checkpointIDCounter)
                             .setCompletedCheckpointStore(store)
-                            .setTasks(new ExecutionVertex[] {stateful1, stateless1})
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
 
@@ -308,18 +298,19 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
             coord.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             statefulExec1.getAttemptId(),
                             checkpointId,
                             new CheckpointMetrics(),
                             subtaskStatesForCheckpoint),
                     TASK_MANAGER_LOCATION_INFO);
             coord.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId),
+                    new AcknowledgeCheckpoint(
+                            graph.getJobID(), statelessExec1.getAttemptId(), checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
 
             CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0);
-            assertEquals(jid, success.getJobId());
+            assertEquals(graph.getJobID(), success.getJobId());
 
             // trigger a savepoint and wait it to be finished
             String savepointDir = tmpFolder.newFolder().getAbsolutePath();
@@ -349,14 +340,15 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
             checkpointId = checkpointIDCounter.getLast();
             coord.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             statefulExec1.getAttemptId(),
                             checkpointId,
                             new CheckpointMetrics(),
                             subtaskStatesForSavepoint),
                     TASK_MANAGER_LOCATION_INFO);
             coord.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId),
+                    new AcknowledgeCheckpoint(
+                            graph.getJobID(), statelessExec1.getAttemptId(), checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
 
             assertNotNull(savepointFuture.get());
@@ -365,37 +357,20 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
             assertTrue(coord.restoreLatestCheckpointedStateToAll(tasks, false));
 
             // compare and see if it used the checkpoint's subtaskStates
-            BaseMatcher<JobManagerTaskRestore> matcher =
-                    new BaseMatcher<JobManagerTaskRestore>() {
-                        @Override
-                        public boolean matches(Object o) {
-                            if (o instanceof JobManagerTaskRestore) {
-                                JobManagerTaskRestore taskRestore = (JobManagerTaskRestore) o;
-                                if (isPreferCheckpoint) {
-                                    return Objects.equals(
-                                            taskRestore.getTaskStateSnapshot(),
-                                            subtaskStatesForCheckpoint);
-                                } else {
-                                    return Objects.equals(
-                                            taskRestore.getTaskStateSnapshot(),
-                                            subtaskStatesForSavepoint);
-                                }
-                            }
-                            return false;
-                        }
-
-                        @Override
-                        public void describeTo(Description description) {
-                            if (isPreferCheckpoint) {
-                                description.appendValue(subtaskStatesForCheckpoint);
-                            } else {
-                                description.appendValue(subtaskStatesForSavepoint);
-                            }
-                        }
-                    };
-
-            verify(statefulExec1, times(1)).setInitialState(MockitoHamcrest.argThat(matcher));
-            verify(statelessExec1, times(0)).setInitialState(Mockito.<JobManagerTaskRestore>any());
+            assertNotNull(
+                    "Stateful vertex should get state to restore", statefulExec1.getTaskRestore());
+            if (isPreferCheckpoint) {
+                assertEquals(
+                        subtaskStatesForCheckpoint,
+                        statefulExec1.getTaskRestore().getTaskStateSnapshot());
+            } else {
+                assertEquals(
+                        subtaskStatesForSavepoint,
+                        statefulExec1.getTaskRestore().getTaskStateSnapshot());
+            }
+            assertNull(
+                    "Stateless vertex should not get state to restore",
+                    statelessExec1.getTaskRestore());
 
             coord.shutdown();
         } catch (Exception e) {
@@ -412,8 +387,6 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
      */
     private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut)
             throws Exception {
-        final JobID jid = new JobID();
-
         final JobVertexID jobVertexID1 = new JobVertexID();
         final JobVertexID jobVertexID2 = new JobVertexID();
         int parallelism1 = 3;
@@ -424,24 +397,22 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
         int newParallelism2 = scaleOut ? 13 : 2;
 
-        final ExecutionJobVertex jobVertex1 =
-                mockExecutionJobVertex(jobVertexID1, parallelism1, maxParallelism1);
-        final ExecutionJobVertex jobVertex2 =
-                mockExecutionJobVertex(jobVertexID2, parallelism2, maxParallelism2);
-
-        List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+        CompletedCheckpointStore completedCheckpointStore = new EmbeddedCompletedCheckpointStore();
 
-        allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
-        allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+        final ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, parallelism1, maxParallelism1)
+                        .addJobVertex(jobVertexID2, parallelism2, maxParallelism2)
+                        .build();
 
-        ExecutionVertex[] arrayExecutionVertices =
-                allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+        final ExecutionJobVertex jobVertex1 = graph.getJobVertex(jobVertexID1);
+        final ExecutionJobVertex jobVertex2 = graph.getJobVertex(jobVertexID2);
 
         // set up the coordinator and validate the initial state
         CheckpointCoordinator coord =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jid)
-                        .setTasks(arrayExecutionVertices)
+                        .setExecutionGraph(graph)
+                        .setCompletedCheckpointStore(completedCheckpointStore)
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
 
@@ -477,7 +448,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
             AcknowledgeCheckpoint acknowledgeCheckpoint =
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             jobVertex1
                                     .getTaskVertices()[index]
                                     .getCurrentExecutionAttempt()
@@ -521,7 +492,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
             AcknowledgeCheckpoint acknowledgeCheckpoint =
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             jobVertex2
                                     .getTaskVertices()[index]
                                     .getCurrentExecutionAttempt()
@@ -537,21 +508,31 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
         assertEquals(1, completedCheckpoints.size());
 
-        Set<ExecutionJobVertex> tasks = new HashSet<>();
-
         List<KeyGroupRange> newKeyGroupPartitions2 =
                 StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
 
-        final ExecutionJobVertex newJobVertex1 =
-                mockExecutionJobVertex(jobVertexID1, parallelism1, maxParallelism1);
-
         // rescale vertex 2
-        final ExecutionJobVertex newJobVertex2 =
-                mockExecutionJobVertex(jobVertexID2, newParallelism2, maxParallelism2);
+        final ExecutionGraph newGraph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, parallelism1, maxParallelism1)
+                        .addJobVertex(jobVertexID2, newParallelism2, maxParallelism2)
+                        .build();
 
+        final ExecutionJobVertex newJobVertex1 = newGraph.getJobVertex(jobVertexID1);
+        final ExecutionJobVertex newJobVertex2 = newGraph.getJobVertex(jobVertexID2);
+
+        // set up the coordinator and validate the initial state
+        CheckpointCoordinator newCoord =
+                new CheckpointCoordinatorBuilder()
+                        .setExecutionGraph(newGraph)
+                        .setCompletedCheckpointStore(completedCheckpointStore)
+                        .setTimer(manuallyTriggeredScheduledExecutor)
+                        .build();
+
+        Set<ExecutionJobVertex> tasks = new HashSet<>();
         tasks.add(newJobVertex1);
         tasks.add(newJobVertex2);
-        assertTrue(coord.restoreLatestCheckpointedStateToAll(tasks, false));
+        assertTrue(newCoord.restoreLatestCheckpointedStateToAll(tasks, false));
 
         // verify the restored state
         verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
@@ -616,8 +597,6 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
      */
     @Test(expected = IllegalStateException.class)
     public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
-        final JobID jid = new JobID();
-
         final JobVertexID jobVertexID1 = new JobVertexID();
         final JobVertexID jobVertexID2 = new JobVertexID();
         int parallelism1 = 3;
@@ -625,24 +604,21 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
         int maxParallelism1 = 42;
         int maxParallelism2 = 13;
 
-        final ExecutionJobVertex jobVertex1 =
-                mockExecutionJobVertex(jobVertexID1, parallelism1, maxParallelism1);
-        final ExecutionJobVertex jobVertex2 =
-                mockExecutionJobVertex(jobVertexID2, parallelism2, maxParallelism2);
-
-        List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+        CompletedCheckpointStore completedCheckpointStore = new EmbeddedCompletedCheckpointStore();
 
-        allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
-        allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
-
-        ExecutionVertex[] arrayExecutionVertices =
-                allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, parallelism1, maxParallelism1)
+                        .addJobVertex(jobVertexID2, parallelism2, maxParallelism2)
+                        .build();
+        ExecutionJobVertex jobVertex1 = graph.getJobVertex(jobVertexID1);
+        ExecutionJobVertex jobVertex2 = graph.getJobVertex(jobVertexID2);
 
         // set up the coordinator and validate the initial state
         CheckpointCoordinator coord =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jid)
-                        .setTasks(arrayExecutionVertices)
+                        .setExecutionGraph(graph)
+                        .setCompletedCheckpointStore(completedCheckpointStore)
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
 
@@ -668,7 +644,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
                     OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
             AcknowledgeCheckpoint acknowledgeCheckpoint =
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             jobVertex1
                                     .getTaskVertices()[index]
                                     .getCurrentExecutionAttempt()
@@ -690,7 +666,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
                     OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
             AcknowledgeCheckpoint acknowledgeCheckpoint =
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             jobVertex2
                                     .getTaskVertices()[index]
                                     .getCurrentExecutionAttempt()
@@ -706,21 +682,30 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 
         assertEquals(1, completedCheckpoints.size());
 
-        Set<ExecutionJobVertex> tasks = new HashSet<>();
-
         int newMaxParallelism1 = 20;
         int newMaxParallelism2 = 42;
 
-        final ExecutionJobVertex newJobVertex1 =
-                mockExecutionJobVertex(jobVertexID1, parallelism1, newMaxParallelism1);
+        ExecutionGraph newGraph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, parallelism1, newMaxParallelism1)
+                        .addJobVertex(jobVertexID2, parallelism2, newMaxParallelism2)
+                        .build();
+
+        ExecutionJobVertex newJobVertex1 = newGraph.getJobVertex(jobVertexID1);
+        ExecutionJobVertex newJobVertex2 = newGraph.getJobVertex(jobVertexID2);
 
-        final ExecutionJobVertex newJobVertex2 =
-                mockExecutionJobVertex(jobVertexID2, parallelism2, newMaxParallelism2);
+        // set up the coordinator and validate the initial state
+        CheckpointCoordinator newCoord =
+                new CheckpointCoordinatorBuilder()
+                        .setExecutionGraph(newGraph)
+                        .setCompletedCheckpointStore(completedCheckpointStore)
+                        .setTimer(manuallyTriggeredScheduledExecutor)
+                        .build();
 
+        Set<ExecutionJobVertex> tasks = new HashSet<>();
         tasks.add(newJobVertex1);
         tasks.add(newJobVertex2);
-
-        assertTrue(coord.restoreLatestCheckpointedStateToAll(tasks, false));
+        assertTrue(newCoord.restoreLatestCheckpointedStateToAll(tasks, false));
 
         fail("The restoration should have failed because the max parallelism changed.");
     }
@@ -854,27 +839,37 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
         List<KeyGroupRange> newKeyGroupPartitions2 =
                 StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
 
-        final ExecutionJobVertex newJobVertex1 =
-                mockExecutionJobVertex(
-                        id5.f0,
-                        Arrays.asList(id2.f1, id1.f1, id5.f1),
-                        newParallelism1,
-                        maxParallelism1);
+        ExecutionGraph newGraph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(
+                                id5.f0,
+                                newParallelism1,
+                                maxParallelism1,
+                                Stream.of(id2.f1, id1.f1, id5.f1)
+                                        .map(OperatorIDPair::generatedIDOnly)
+                                        .collect(Collectors.toList()),
+                                true)
+                        .addJobVertex(
+                                id3.f0,
+                                newParallelism2,
+                                maxParallelism2,
+                                Stream.of(id6.f1, id3.f1)
+                                        .map(OperatorIDPair::generatedIDOnly)
+                                        .collect(Collectors.toList()),
+                                true)
+                        .build();
 
-        final ExecutionJobVertex newJobVertex2 =
-                mockExecutionJobVertex(
-                        id3.f0, Arrays.asList(id6.f1, id3.f1), newParallelism2, maxParallelism2);
+        ExecutionJobVertex newJobVertex1 = newGraph.getJobVertex(id5.f0);
+        ExecutionJobVertex newJobVertex2 = newGraph.getJobVertex(id3.f0);
 
         Set<ExecutionJobVertex> tasks = new HashSet<>();
 
         tasks.add(newJobVertex1);
         tasks.add(newJobVertex2);
 
-        JobID jobID = new JobID();
-
         CompletedCheckpoint completedCheckpoint =
                 new CompletedCheckpoint(
-                        jobID,
+                        newGraph.getJobID(),
                         2,
                         System.currentTimeMillis(),
                         System.currentTimeMillis() + 3000,
@@ -887,7 +882,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
         // set up the coordinator and validate the initial state
         CheckpointCoordinator coord =
                 new CheckpointCoordinatorBuilder()
-                        .setTasks(newJobVertex1.getTaskVertices())
+                        .setExecutionGraph(newGraph)
                         .setCompletedCheckpointStore(
                                 CompletedCheckpointStore.storeFor(() -> {}, completedCheckpoint))
                         .setTimer(manuallyTriggeredScheduledExecutor)
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 03dbd55..540ad49 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
@@ -32,12 +33,12 @@ import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
-import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration.CheckpointCoordinatorConfigurationBuilder;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
 import org.apache.flink.runtime.state.CheckpointMetadataOutputStream;
@@ -59,12 +60,14 @@ import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAcces
 import org.apache.flink.runtime.state.memory.NonPersistentMetadataCheckpointStorageLocation;
 import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
+import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.TestLogger;
 import org.apache.flink.util.function.TriFunctionWithException;
 
 import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
 
+import com.sun.istack.Nullable;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
@@ -72,8 +75,6 @@ import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 import org.mockito.verification.VerificationMode;
 
-import javax.annotation.Nullable;
-
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -96,11 +97,8 @@ import java.util.concurrent.Future;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
 import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_ASYNC_EXCEPTION;
 import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED;
 import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_EXPIRED;
@@ -112,9 +110,8 @@ import static org.junit.Assert.assertNotEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.anyLong;
-import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -129,42 +126,53 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
     @Test
     public void testAbortedCheckpointStatsUpdatedAfterFailure() throws Exception {
-        JobID jobID = new JobID();
         testReportStatsAfterFailure(
-                jobID,
                 1L,
-                (coordinator, attemptID, metrics) -> {
-                    coordinator.reportStats(1L, attemptID, metrics);
+                (coordinator, execution, metrics) -> {
+                    coordinator.reportStats(1L, execution.getAttemptId(), metrics);
                     return null;
                 });
     }
 
     @Test
     public void testCheckpointStatsUpdatedAfterFailure() throws Exception {
-        JobID jobID = new JobID();
         testReportStatsAfterFailure(
-                jobID,
                 1L,
-                (coordinator, attemptID, metrics) ->
+                (coordinator, execution, metrics) ->
                         coordinator.receiveAcknowledgeMessage(
                                 new AcknowledgeCheckpoint(
-                                        jobID, attemptID, 1L, metrics, new TaskStateSnapshot()),
+                                        execution.getVertex().getJobId(),
+                                        execution.getAttemptId(),
+                                        1L,
+                                        metrics,
+                                        new TaskStateSnapshot()),
                                 TASK_MANAGER_LOCATION_INFO));
     }
 
     private void testReportStatsAfterFailure(
-            JobID jobID,
             long checkpointId,
             TriFunctionWithException<
                             CheckpointCoordinator,
-                            ExecutionAttemptID,
+                            Execution,
                             CheckpointMetrics,
                             ?,
                             CheckpointException>
                     reportFn)
             throws Exception {
-        ExecutionVertex decliningVertex = mockExecutionVertex(new ExecutionAttemptID());
-        ExecutionVertex lateReportVertex = mockExecutionVertex(new ExecutionAttemptID());
+
+        JobVertexID decliningVertexID = new JobVertexID();
+        JobVertexID lateReportVertexID = new JobVertexID();
+
+        ExecutionGraph executionGraph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(decliningVertexID)
+                        .addJobVertex(lateReportVertexID)
+                        .build();
+
+        ExecutionVertex decliningVertex =
+                executionGraph.getJobVertex(decliningVertexID).getTaskVertices()[0];
+        ExecutionVertex lateReportVertex =
+                executionGraph.getJobVertex(lateReportVertexID).getTaskVertices()[0];
         CheckpointStatsTracker statsTracker =
                 new CheckpointStatsTracker(
                         Integer.MAX_VALUE,
@@ -172,9 +180,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
                         new UnregisteredMetricsGroup());
         CheckpointCoordinator coordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobID)
+                        .setExecutionGraph(executionGraph)
                         .setTimer(manuallyTriggeredScheduledExecutor)
-                        .setTasks(decliningVertex, lateReportVertex)
                         .build();
         coordinator.setCheckpointStatsTracker(statsTracker);
 
@@ -190,7 +197,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         coordinator.receiveDeclineMessage(
                 new DeclineCheckpoint(
-                        jobID,
+                        executionGraph.getJobID(),
                         decliningVertex.getCurrentExecutionAttempt().getAttemptId(),
                         checkpointId,
                         new CheckpointException(CHECKPOINT_DECLINED)),
@@ -206,9 +213,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                         .build();
 
         reportFn.apply(
-                coordinator,
-                lateReportVertex.getCurrentExecutionAttempt().getAttemptId(),
-                lateReportedMetrics);
+                coordinator, lateReportVertex.getCurrentExecutionAttempt(), lateReportedMetrics);
 
         assertStatsEqual(
                 checkpointId,
@@ -260,8 +265,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
     public void testScheduleTriggerRequestDuringShutdown() throws Exception {
         ScheduledExecutorService executor = Executors.newSingleThreadScheduledExecutor();
         CheckpointCoordinator coordinator =
-                getCheckpointCoordinator(
-                        new ScheduledExecutorServiceAdapter(executor), ExecutionState.RUNNING);
+                getCheckpointCoordinator(new ScheduledExecutorServiceAdapter(executor));
         coordinator.shutdown();
         executor.shutdownNow();
         coordinator.scheduleTriggerRequest(); // shouldn't fail
@@ -274,9 +278,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
         ScheduledExecutorService executorService = Executors.newSingleThreadScheduledExecutor();
         try {
             int pause = 1000;
-            JobID jobId = new JobID();
-            ExecutionAttemptID attemptId = new ExecutionAttemptID();
-            ExecutionVertex vertex = mockExecutionVertex(attemptId);
+            JobVertexID jobVertexId = new JobVertexID();
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexId)
+                            .setMainThreadExecutor(
+                                    ComponentMainThreadExecutorServiceAdapter
+                                            .forSingleThreadExecutor(
+                                                    new DirectScheduledExecutorService()))
+                            .build();
+
+            ExecutionVertex vertex = graph.getJobVertex(jobVertexId).getTaskVertices()[0];
+            ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
 
             CheckpointCoordinator coordinator =
                     new CheckpointCoordinatorBuilder()
@@ -288,10 +301,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                                             .setMaxConcurrentCheckpoints(1)
                                             .setMinPauseBetweenCheckpoints(pause)
                                             .build())
-                            .setTasksToTrigger(new ExecutionVertex[] {vertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {vertex})
-                            .setTasksToCommitTo(new ExecutionVertex[] {vertex})
-                            .setJobId(jobId)
+                            .setExecutionGraph(graph)
                             .build();
             coordinator.startCheckpointScheduler();
 
@@ -305,7 +315,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
                 Thread.sleep(10);
             }
             coordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptId, 1L), TASK_MANAGER_LOCATION_INFO);
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptId, 1L),
+                    TASK_MANAGER_LOCATION_INFO);
             Thread.sleep(pause / 2);
             assertEquals(0, coordinator.getNumberOfPendingCheckpoints());
             Thread.sleep(pause);
@@ -318,10 +329,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testCheckpointAbortsIfTriggerTasksAreNotExecuted() {
         try {
-
             // set up the coordinator and validate the initial state
-            CheckpointCoordinator checkpointCoordinator =
-                    getCheckpointCoordinator(ExecutionState.CREATED);
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(new JobVertexID())
+                            .addJobVertex(new JobVertexID(), false)
+                            .setTransitToRunning(false)
+                            .build();
+
+            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator(graph);
 
             // nothing should be happening
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -347,8 +363,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testCheckpointAbortsIfTriggerTasksAreFinished() {
         try {
-            CheckpointCoordinator checkpointCoordinator =
-                    getCheckpointCoordinator(ExecutionState.FINISHED);
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
+
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2, false)
+                            .build();
+
+            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator(graph);
+            Arrays.stream(graph.getJobVertex(jobVertexID1).getTaskVertices())
+                    .forEach(task -> task.getCurrentExecutionAttempt().markFinished());
 
             // nothing should be happening
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -372,14 +398,85 @@ public class CheckpointCoordinatorTest extends TestLogger {
     }
 
     @Test
-    public void testTriggerAndDeclineCheckpointThenFailureManagerThrowsException() {
-        final JobID jobId = new JobID();
+    public void testCheckpointTriggeredAfterSomeTasksFinishedIfAllowed() throws Exception {
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
 
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-        ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-        ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, 3, 256)
+                        .addJobVertex(jobVertexID2, 3, 256)
+                        .build();
+        ExecutionJobVertex jobVertex1 = graph.getJobVertex(jobVertexID1);
+        ExecutionJobVertex jobVertex2 = graph.getJobVertex(jobVertexID2);
+
+        jobVertex1.getTaskVertices()[0].getCurrentExecutionAttempt().markFinished();
+        jobVertex1.getTaskVertices()[1].getCurrentExecutionAttempt().markFinished();
+        jobVertex2.getTaskVertices()[1].getCurrentExecutionAttempt().markFinished();
+
+        CheckpointCoordinator checkpointCoordinator =
+                new CheckpointCoordinatorBuilder()
+                        .setExecutionGraph(graph)
+                        .setTimer(manuallyTriggeredScheduledExecutor)
+                        .setAllowCheckpointsAfterTasksFinished(true)
+                        .build();
+
+        CheckpointStatsTracker statsTracker =
+                new CheckpointStatsTracker(
+                        Integer.MAX_VALUE,
+                        CheckpointCoordinatorConfiguration.builder().build(),
+                        new UnregisteredMetricsGroup());
+        checkpointCoordinator.setCheckpointStatsTracker(statsTracker);
+
+        // nothing should be happening
+        assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
+        assertEquals(0, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
+
+        // trigger the first checkpoint. this will not fail because we allow checkpointing even with
+        // finished tasks
+        final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+                checkpointCoordinator.triggerCheckpoint(false);
+        manuallyTriggeredScheduledExecutor.triggerAll();
+        assertFalse(checkpointFuture.isDone());
+        assertFalse(checkpointFuture.isCompletedExceptionally());
+
+        // Triggering should succeed
+        assertEquals(1, checkpointCoordinator.getNumberOfPendingCheckpoints());
+        PendingCheckpoint pendingCheckpoint =
+                checkpointCoordinator.getPendingCheckpoints().values().iterator().next();
+        AbstractCheckpointStats checkpointStats =
+                statsTracker
+                        .createSnapshot()
+                        .getHistory()
+                        .getCheckpointById(pendingCheckpoint.getCheckpointID());
+        assertEquals(3, checkpointStats.getNumberOfAcknowledgedSubtasks());
+        for (ExecutionVertex task :
+                Arrays.asList(
+                        jobVertex1.getTaskVertices()[0],
+                        jobVertex1.getTaskVertices()[1],
+                        jobVertex2.getTaskVertices()[1])) {
+
+            // those tasks that are already finished are automatically marked as acknowledged
+            assertNotNull(
+                    checkpointStats.getTaskStateStats(task.getJobvertexId())
+                            .getSubtaskStats()[task.getParallelSubtaskIndex()]);
+        }
+    }
+
+    @Test
+    public void testTriggerAndDeclineCheckpointThenFailureManagerThrowsException()
+            throws Exception {
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2)
+                        .build();
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+        final ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        final ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
 
         final String errorMsg = "Exceeded checkpoint failure tolerance number!";
 
@@ -387,7 +484,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         // set up the coordinator
         CheckpointCoordinator checkpointCoordinator =
-                getCheckpointCoordinator(jobId, vertex1, vertex2, checkpointFailureManager);
+                getCheckpointCoordinator(graph, checkpointFailureManager);
 
         try {
             // trigger the checkpoint. this should succeed
@@ -408,7 +505,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             // acknowledge from one of the tasks
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptID2, checkpointId),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
             assertFalse(checkpoint.isDisposed());
             assertFalse(checkpoint.areTasksFullyAcknowledged());
@@ -416,7 +513,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // decline checkpoint from the other task
             checkpointCoordinator.receiveDeclineMessage(
                     new DeclineCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             attemptID1,
                             checkpointId,
                             new CheckpointException(CHECKPOINT_DECLINED)),
@@ -440,13 +537,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testExpiredCheckpointExceedsTolerableFailureNumber() throws Exception {
         // create some mock Execution vertices that receive the checkpoint trigger messages
-        ExecutionVertex vertex1 = mockExecutionVertex(new ExecutionAttemptID());
-        ExecutionVertex vertex2 = mockExecutionVertex(new ExecutionAttemptID());
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .addJobVertex(new JobVertexID())
+                        .build();
 
         final String errorMsg = "Exceeded checkpoint failure tolerance number!";
         CheckpointFailureManager checkpointFailureManager = getCheckpointFailureManager(errorMsg);
         CheckpointCoordinator checkpointCoordinator =
-                getCheckpointCoordinator(new JobID(), vertex1, vertex2, checkpointFailureManager);
+                getCheckpointCoordinator(graph, checkpointFailureManager);
 
         try {
             checkpointCoordinator.triggerCheckpoint(false);
@@ -485,20 +585,31 @@ public class CheckpointCoordinatorTest extends TestLogger {
         try {
             final CheckpointException checkpointException =
                     new CheckpointException(checkpointFailureReason);
-            final JobID jobId = new JobID();
 
-            // create some mock Execution vertices that receive the checkpoint trigger messages
-            final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-            ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-            ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
+
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2)
+                            .setTaskManagerGateway(gateway)
+                            .build();
+
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+            ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
 
             TestFailJobCallback failJobCallback = new TestFailJobCallback();
             // set up the coordinator and validate the initial state
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
-                            .setTasks(new ExecutionVertex[] {vertex1, vertex2})
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(
                                     CheckpointCoordinatorConfiguration.builder()
                                             .setAlignmentTimeout(Long.MAX_VALUE)
@@ -537,7 +648,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             assertNotNull(checkpoint);
             assertEquals(checkpointId, checkpoint.getCheckpointId());
-            assertEquals(jobId, checkpoint.getJobId());
+            assertEquals(graph.getJobID(), checkpoint.getJobId());
             assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks());
             assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks());
             assertEquals(0, checkpoint.getOperatorStates().size());
@@ -545,20 +656,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertFalse(checkpoint.areTasksFullyAcknowledged());
 
             // check that the vertices received the trigger checkpoint message
-            verify(vertex1.getCurrentExecutionAttempt())
-                    .triggerCheckpoint(
-                            checkpointId,
-                            checkpoint.getCheckpointTimestamp(),
-                            CheckpointOptions.forCheckpointWithDefaultLocation());
-            verify(vertex2.getCurrentExecutionAttempt())
-                    .triggerCheckpoint(
-                            checkpointId,
-                            checkpoint.getCheckpointTimestamp(),
-                            CheckpointOptions.forCheckpointWithDefaultLocation());
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                CheckpointCoordinatorTestingUtils.TriggeredCheckpoint triggeredCheckpoint =
+                        gateway.getOnlyTriggeredCheckpoint(
+                                vertex.getCurrentExecutionAttempt().getAttemptId());
+                assertEquals(checkpointId, triggeredCheckpoint.checkpointId);
+                assertEquals(checkpoint.getCheckpointTimestamp(), triggeredCheckpoint.timestamp);
+                assertEquals(
+                        CheckpointOptions.forCheckpointWithDefaultLocation(),
+                        triggeredCheckpoint.checkpointOptions);
+            }
 
             // acknowledge from one of the tasks
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptID2, checkpointId), "Unknown location");
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointId),
+                    "Unknown location");
             assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks());
             assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks());
             assertFalse(checkpoint.isDisposed());
@@ -566,14 +678,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             // acknowledge the same task again (should not matter)
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptID2, checkpointId), "Unknown location");
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointId),
+                    "Unknown location");
             assertFalse(checkpoint.isDisposed());
             assertFalse(checkpoint.areTasksFullyAcknowledged());
 
             // decline checkpoint from the other task, this should cancel the checkpoint
             // and trigger a new one
             checkpointCoordinator.receiveDeclineMessage(
-                    new DeclineCheckpoint(jobId, attemptID1, checkpointId, checkpointException),
+                    new DeclineCheckpoint(
+                            graph.getJobID(), attemptID1, checkpointId, checkpointException),
                     TASK_MANAGER_LOCATION_INFO);
             assertTrue(checkpoint.isDisposed());
 
@@ -587,10 +701,12 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // decline again, nothing should happen
             // decline from the other task, nothing should happen
             checkpointCoordinator.receiveDeclineMessage(
-                    new DeclineCheckpoint(jobId, attemptID1, checkpointId, checkpointException),
+                    new DeclineCheckpoint(
+                            graph.getJobID(), attemptID1, checkpointId, checkpointException),
                     TASK_MANAGER_LOCATION_INFO);
             checkpointCoordinator.receiveDeclineMessage(
-                    new DeclineCheckpoint(jobId, attemptID2, checkpointId, checkpointException),
+                    new DeclineCheckpoint(
+                            graph.getJobID(), attemptID2, checkpointId, checkpointException),
                     TASK_MANAGER_LOCATION_INFO);
             assertTrue(checkpoint.isDisposed());
             assertEquals(1, failJobCallback.getInvokeCounter());
@@ -610,16 +726,25 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testTriggerAndDeclineCheckpointComplex() {
         try {
-            final JobID jobId = new JobID();
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
 
-            // create some mock Execution vertices that receive the checkpoint trigger messages
-            final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-            ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-            ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
-            // set up the coordinator and validate the initial state
-            CheckpointCoordinator checkpointCoordinator =
-                    getCheckpointCoordinator(jobId, vertex1, vertex2);
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2)
+                            .setTaskManagerGateway(gateway)
+                            .build();
+
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+            ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
+            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator(graph);
 
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
             assertEquals(0, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
@@ -653,7 +778,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             assertNotNull(checkpoint1);
             assertEquals(checkpoint1Id, checkpoint1.getCheckpointId());
-            assertEquals(jobId, checkpoint1.getJobId());
+            assertEquals(graph.getJobID(), checkpoint1.getJobId());
             assertEquals(2, checkpoint1.getNumberOfNonAcknowledgedTasks());
             assertEquals(0, checkpoint1.getNumberOfAcknowledgedTasks());
             assertEquals(0, checkpoint1.getOperatorStates().size());
@@ -662,7 +787,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             assertNotNull(checkpoint2);
             assertEquals(checkpoint2Id, checkpoint2.getCheckpointId());
-            assertEquals(jobId, checkpoint2.getJobId());
+            assertEquals(graph.getJobID(), checkpoint2.getJobId());
             assertEquals(2, checkpoint2.getNumberOfNonAcknowledgedTasks());
             assertEquals(0, checkpoint2.getNumberOfAcknowledgedTasks());
             assertEquals(0, checkpoint2.getOperatorStates().size());
@@ -670,38 +795,30 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertFalse(checkpoint2.areTasksFullyAcknowledged());
 
             // check that the vertices received the trigger checkpoint message
-            {
-                verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpoint1Id), any(Long.class), any(CheckpointOptions.class));
-                verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpoint1Id), any(Long.class), any(CheckpointOptions.class));
-            }
-
-            // check that the vertices received the trigger checkpoint message for the second
-            // checkpoint
-            {
-                verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpoint2Id), any(Long.class), any(CheckpointOptions.class));
-                verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpoint2Id), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                List<CheckpointCoordinatorTestingUtils.TriggeredCheckpoint> triggeredCheckpoints =
+                        gateway.getTriggeredCheckpoints(
+                                vertex.getCurrentExecutionAttempt().getAttemptId());
+                assertEquals(2, triggeredCheckpoints.size());
+                assertEquals(checkpoint1Id, triggeredCheckpoints.get(0).checkpointId);
+                assertEquals(checkpoint2Id, triggeredCheckpoints.get(1).checkpointId);
             }
 
             // decline checkpoint from one of the tasks, this should cancel the checkpoint
             checkpointCoordinator.receiveDeclineMessage(
                     new DeclineCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             attemptID1,
                             checkpoint1Id,
                             new CheckpointException(CHECKPOINT_DECLINED)),
                     TASK_MANAGER_LOCATION_INFO);
-            verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointAborted(eq(checkpoint1Id), any(Long.class));
-            verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointAborted(eq(checkpoint1Id), any(Long.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                assertEquals(
+                        checkpoint1Id,
+                        gateway.getOnlyNotifiedAbortedCheckpoint(
+                                        vertex.getCurrentExecutionAttempt().getAttemptId())
+                                .checkpointId);
+            }
 
             assertTrue(checkpoint1.isDisposed());
 
@@ -724,7 +841,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             assertNotNull(checkpointNew);
             assertEquals(checkpointIdNew, checkpointNew.getCheckpointId());
-            assertEquals(jobId, checkpointNew.getJobId());
+            assertEquals(graph.getJobID(), checkpointNew.getJobId());
             assertEquals(2, checkpointNew.getNumberOfNonAcknowledgedTasks());
             assertEquals(0, checkpointNew.getNumberOfAcknowledgedTasks());
             assertEquals(0, checkpointNew.getOperatorStates().size());
@@ -736,14 +853,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // decline from the other task, nothing should happen
             checkpointCoordinator.receiveDeclineMessage(
                     new DeclineCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             attemptID1,
                             checkpoint1Id,
                             new CheckpointException(CHECKPOINT_DECLINED)),
                     TASK_MANAGER_LOCATION_INFO);
             checkpointCoordinator.receiveDeclineMessage(
                     new DeclineCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             attemptID2,
                             checkpoint1Id,
                             new CheckpointException(CHECKPOINT_DECLINED)),
@@ -751,10 +868,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertTrue(checkpoint1.isDisposed());
 
             // will not notify abort message again
-            verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointAborted(eq(checkpoint1Id), any(Long.class));
-            verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointAborted(eq(checkpoint1Id), any(Long.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                assertEquals(
+                        1,
+                        gateway.getNotifiedAbortedCheckpoints(
+                                        vertex.getCurrentExecutionAttempt().getAttemptId())
+                                .size());
+            }
 
             checkpointCoordinator.shutdown();
         } catch (Exception e) {
@@ -766,17 +886,25 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testTriggerAndConfirmSimpleCheckpoint() {
         try {
-            final JobID jobId = new JobID();
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
 
-            // create some mock Execution vertices that receive the checkpoint trigger messages
-            final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-            ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-            ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-            // set up the coordinator and validate the initial state
-            CheckpointCoordinator checkpointCoordinator =
-                    getCheckpointCoordinator(jobId, vertex1, vertex2);
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2)
+                            .setTaskManagerGateway(gateway)
+                            .build();
+
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+            ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
+            CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator(graph);
 
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
             assertEquals(0, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
@@ -805,7 +933,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             assertNotNull(checkpoint);
             assertEquals(checkpointId, checkpoint.getCheckpointId());
-            assertEquals(jobId, checkpoint.getJobId());
+            assertEquals(graph.getJobID(), checkpoint.getJobId());
             assertEquals(2, checkpoint.getNumberOfNonAcknowledgedTasks());
             assertEquals(0, checkpoint.getNumberOfAcknowledgedTasks());
             assertEquals(0, checkpoint.getOperatorStates().size());
@@ -813,17 +941,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertFalse(checkpoint.areTasksFullyAcknowledged());
 
             // check that the vertices received the trigger checkpoint message
-            {
-                verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpointId), any(Long.class), any(CheckpointOptions.class));
-                verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpointId), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
             }
 
-            OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
-            OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
+            OperatorID opID1 =
+                    vertex1.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
+            OperatorID opID2 =
+                    vertex2.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
             TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class);
             TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class);
             OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
@@ -836,7 +963,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // acknowledge from one of the tasks
             AcknowledgeCheckpoint acknowledgeCheckpoint1 =
                     new AcknowledgeCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             attemptID2,
                             checkpointId,
                             new CheckpointMetrics(),
@@ -860,7 +987,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // acknowledge the other task.
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             attemptID1,
                             checkpointId,
                             new CheckpointMetrics(),
@@ -887,23 +1014,22 @@ public class CheckpointCoordinatorTest extends TestLogger {
             }
 
             // validate that the relevant tasks got a confirmation message
-            {
-                verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpointId), any(Long.class), any(CheckpointOptions.class));
-                verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpointId), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId,
+                        gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
             }
 
             CompletedCheckpoint success = checkpointCoordinator.getSuccessfulCheckpoints().get(0);
-            assertEquals(jobId, success.getJobId());
+            assertEquals(graph.getJobID(), success.getJobId());
             assertEquals(checkpoint.getCheckpointId(), success.getCheckpointID());
             assertEquals(2, success.getOperatorStates().size());
 
             // ---------------
             // trigger another checkpoint and see that this one replaces the other checkpoint
             // ---------------
+            gateway.resetCount();
             checkpointCoordinator.triggerCheckpoint(false);
             manuallyTriggeredScheduledExecutor.triggerAll();
 
@@ -915,10 +1041,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
                             .next()
                             .getKey();
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptID1, checkpointIdNew),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, checkpointIdNew),
                     TASK_MANAGER_LOCATION_INFO);
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptID2, checkpointIdNew),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointIdNew),
                     TASK_MANAGER_LOCATION_INFO);
 
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -927,23 +1053,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             CompletedCheckpoint successNew =
                     checkpointCoordinator.getSuccessfulCheckpoints().get(0);
-            assertEquals(jobId, successNew.getJobId());
+            assertEquals(graph.getJobID(), successNew.getJobId());
             assertEquals(checkpointIdNew, successNew.getCheckpointID());
             assertTrue(successNew.getOperatorStates().isEmpty());
 
             // validate that the relevant tasks got a confirmation message
-            {
-                verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpointIdNew), any(Long.class), any(CheckpointOptions.class));
-                verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                        .triggerCheckpoint(
-                                eq(checkpointIdNew), any(Long.class), any(CheckpointOptions.class));
-
-                verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                        .notifyCheckpointComplete(eq(checkpointIdNew), any(Long.class));
-                verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                        .notifyCheckpointComplete(eq(checkpointIdNew), any(Long.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointIdNew,
+                        gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+                assertEquals(
+                        checkpointIdNew,
+                        gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
             }
 
             checkpointCoordinator.shutdown();
@@ -956,41 +1078,37 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testMultipleConcurrentCheckpoints() {
         try {
-            final JobID jobId = new JobID();
-
-            // create some mock execution vertices
-
-            final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID();
-
-            final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID3 = new ExecutionAttemptID();
-
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
-
-            ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1);
-            ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2);
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
+            JobVertexID jobVertexID3 = new JobVertexID();
+
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2)
+                            .addJobVertex(jobVertexID3, false)
+                            .setTaskManagerGateway(gateway)
+                            .build();
 
-            ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1);
-            ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2);
-            ExecutionVertex ackVertex3 = mockExecutionVertex(ackAttemptID3);
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+            ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+            ExecutionVertex vertex3 = graph.getJobVertex(jobVertexID3).getTaskVertices()[0];
 
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID3 = vertex3.getCurrentExecutionAttempt().getAttemptId();
 
             // set up the coordinator and validate the initial state
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(
                                     CheckpointCoordinatorConfiguration.builder()
                                             .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
                                             .build())
-                            .setTasksToTrigger(
-                                    new ExecutionVertex[] {triggerVertex1, triggerVertex2})
-                            .setTasksToWaitFor(
-                                    new ExecutionVertex[] {ackVertex1, ackVertex2, ackVertex3})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -1012,20 +1130,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
             long checkpointId1 = pending1.getCheckpointId();
 
             // trigger messages should have been sent
-            verify(triggerVertex1.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId1), any(Long.class), any(CheckpointOptions.class));
-            verify(triggerVertex2.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId1), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId1, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+            }
 
             // acknowledge one of the three tasks
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID2, checkpointId1),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointId1),
                     TASK_MANAGER_LOCATION_INFO);
 
             // start the second checkpoint
             // trigger the first checkpoint. this should succeed
+            gateway.resetCount();
             final CompletableFuture<CompletedCheckpoint> checkpointFuture2 =
                     checkpointCoordinator.triggerCheckpoint(false);
             manuallyTriggeredScheduledExecutor.triggerAll();
@@ -1045,26 +1163,25 @@ public class CheckpointCoordinatorTest extends TestLogger {
             long checkpointId2 = pending2.getCheckpointId();
 
             // trigger messages should have been sent
-            verify(triggerVertex1.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId2), any(Long.class), any(CheckpointOptions.class));
-            verify(triggerVertex2.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId2), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId2, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+            }
 
             // we acknowledge the remaining two tasks from the first
             // checkpoint and two tasks from the second checkpoint
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID3, checkpointId1),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID3, checkpointId1),
                     TASK_MANAGER_LOCATION_INFO);
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID1, checkpointId2),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, checkpointId2),
                     TASK_MANAGER_LOCATION_INFO);
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID1, checkpointId1),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, checkpointId1),
                     TASK_MANAGER_LOCATION_INFO);
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID2, checkpointId2),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointId2),
                     TASK_MANAGER_LOCATION_INFO);
 
             // now, the first checkpoint should be confirmed
@@ -1073,12 +1190,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertTrue(pending1.isDisposed());
 
             // the first confirm message should be out
-            verify(commitVertex.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointId1), any(Long.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2, vertex3)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId1,
+                        gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
+            }
 
             // send the last remaining ack for the second checkpoint
+            gateway.resetCount();
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID3, checkpointId2),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID3, checkpointId2),
                     TASK_MANAGER_LOCATION_INFO);
 
             // now, the second checkpoint should be confirmed
@@ -1087,20 +1209,24 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertTrue(pending2.isDisposed());
 
             // the second commit message should be out
-            verify(commitVertex.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointId2), any(Long.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2, vertex3)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId2,
+                        gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
+            }
 
             // validate the committed checkpoints
             List<CompletedCheckpoint> scs = checkpointCoordinator.getSuccessfulCheckpoints();
 
             CompletedCheckpoint sc1 = scs.get(0);
             assertEquals(checkpointId1, sc1.getCheckpointID());
-            assertEquals(jobId, sc1.getJobId());
+            assertEquals(graph.getJobID(), sc1.getJobId());
             assertTrue(sc1.getOperatorStates().isEmpty());
 
             CompletedCheckpoint sc2 = scs.get(1);
             assertEquals(checkpointId2, sc2.getCheckpointID());
-            assertEquals(jobId, sc2.getJobId());
+            assertEquals(graph.getJobID(), sc2.getJobId());
             assertTrue(sc2.getOperatorStates().isEmpty());
 
             checkpointCoordinator.shutdown();
@@ -1113,42 +1239,39 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testSuccessfulCheckpointSubsumesUnsuccessful() {
         try {
-            final JobID jobId = new JobID();
-
-            // create some mock execution vertices
-            final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID();
-
-            final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID3 = new ExecutionAttemptID();
-
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
-
-            ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1);
-            ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2);
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
+            JobVertexID jobVertexID3 = new JobVertexID();
+
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2)
+                            .addJobVertex(jobVertexID3, false)
+                            .setTaskManagerGateway(gateway)
+                            .build();
 
-            ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1);
-            ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2);
-            ExecutionVertex ackVertex3 = mockExecutionVertex(ackAttemptID3);
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+            ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+            ExecutionVertex vertex3 = graph.getJobVertex(jobVertexID3).getTaskVertices()[0];
 
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
+            ExecutionAttemptID attemptID3 = vertex3.getCurrentExecutionAttempt().getAttemptId();
 
             // set up the coordinator and validate the initial state
             final StandaloneCompletedCheckpointStore completedCheckpointStore =
                     new StandaloneCompletedCheckpointStore(10);
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(
                                     CheckpointCoordinatorConfiguration.builder()
                                             .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
                                             .build())
-                            .setTasksToTrigger(
-                                    new ExecutionVertex[] {triggerVertex1, triggerVertex2})
-                            .setTasksToWaitFor(
-                                    new ExecutionVertex[] {ackVertex1, ackVertex2, ackVertex3})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
                             .setCompletedCheckpointStore(completedCheckpointStore)
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -1170,16 +1293,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
             long checkpointId1 = pending1.getCheckpointId();
 
             // trigger messages should have been sent
-            verify(triggerVertex1.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId1), any(Long.class), any(CheckpointOptions.class));
-            verify(triggerVertex2.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId1), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId1, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+            }
 
-            OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
-            OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
-            OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
+            OperatorID opID1 =
+                    vertex1.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
+            OperatorID opID2 =
+                    vertex2.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
+            OperatorID opID3 =
+                    vertex3.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
 
             TaskStateSnapshot taskOperatorSubtaskStates11 = spy(new TaskStateSnapshot());
             TaskStateSnapshot taskOperatorSubtaskStates12 = spy(new TaskStateSnapshot());
@@ -1195,15 +1320,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // acknowledge one of the three tasks
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID2,
+                            graph.getJobID(),
+                            attemptID2,
                             checkpointId1,
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates12),
                     TASK_MANAGER_LOCATION_INFO);
 
             // start the second checkpoint
-            // trigger the first checkpoint. this should succeed
+            gateway.resetCount();
             final CompletableFuture<CompletedCheckpoint> checkpointFuture2 =
                     checkpointCoordinator.triggerCheckpoint(false);
             manuallyTriggeredScheduledExecutor.triggerAll();
@@ -1235,20 +1360,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
             taskOperatorSubtaskStates23.putSubtaskStateByOperatorID(opID3, subtaskState23);
 
             // trigger messages should have been sent
-            verify(triggerVertex1.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId2), any(Long.class), any(CheckpointOptions.class));
-            verify(triggerVertex2.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointId2), any(Long.class), any(CheckpointOptions.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId2, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+            }
 
             // we acknowledge one more task from the first checkpoint and the second
             // checkpoint completely. The second checkpoint should then subsume the first checkpoint
 
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID3,
+                            graph.getJobID(),
+                            attemptID3,
                             checkpointId2,
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates23),
@@ -1256,8 +1380,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID1,
+                            graph.getJobID(),
+                            attemptID1,
                             checkpointId2,
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates21),
@@ -1265,8 +1389,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID1,
+                            graph.getJobID(),
+                            attemptID1,
                             checkpointId1,
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates11),
@@ -1274,8 +1398,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID2,
+                            graph.getJobID(),
+                            attemptID2,
                             checkpointId2,
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates22),
@@ -1303,18 +1427,22 @@ public class CheckpointCoordinatorTest extends TestLogger {
             List<CompletedCheckpoint> scs = checkpointCoordinator.getSuccessfulCheckpoints();
             CompletedCheckpoint success = scs.get(0);
             assertEquals(checkpointId2, success.getCheckpointID());
-            assertEquals(jobId, success.getJobId());
+            assertEquals(graph.getJobID(), success.getJobId());
             assertEquals(3, success.getOperatorStates().size());
 
             // the first confirm message should be out
-            verify(commitVertex.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointId2), any(Long.class));
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2, vertex3)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(
+                        checkpointId2,
+                        gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
+            }
 
             // send the last remaining ack for the first checkpoint. This should not do anything
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID3,
+                            graph.getJobID(),
+                            attemptID3,
                             checkpointId1,
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates13),
@@ -1338,31 +1466,28 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testCheckpointTimeoutIsolated() {
         try {
-            final JobID jobId = new JobID();
-
-            // create some mock execution vertices
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
 
-            final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-            final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID();
-
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
-
-            ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2, false)
+                            .setTaskManagerGateway(gateway)
+                            .build();
 
-            ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1);
-            ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2);
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+            ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
 
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
 
             // set up the coordinator
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
-                            .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {ackVertex1, ackVertex2})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
+                            .setExecutionGraph(graph)
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -1378,7 +1503,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
                     checkpointCoordinator.getPendingCheckpoints().values().iterator().next();
             assertFalse(checkpoint.isDisposed());
 
-            OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
+            OperatorID opID1 =
+                    vertex1.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
 
             TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new TaskStateSnapshot());
             OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
@@ -1386,8 +1512,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             checkpointCoordinator.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jobId,
-                            ackAttemptID1,
+                            graph.getJobID(),
+                            attemptID1,
                             checkpoint.getCheckpointId(),
                             new CheckpointMetrics(),
                             taskOperatorSubtaskStates1),
@@ -1403,8 +1529,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
             verify(subtaskState1, times(1)).discardState();
 
             // no confirm message must have been sent
-            verify(commitVertex.getCurrentExecutionAttempt(), times(0))
-                    .notifyCheckpointComplete(anyLong(), anyLong());
+            for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+                ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+                assertEquals(0, gateway.getNotifiedCompletedCheckpoints(attemptId).size());
+            }
 
             checkpointCoordinator.shutdown();
         } catch (Exception e) {
@@ -1416,26 +1544,27 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testHandleMessagesForNonExistingCheckpoints() {
         try {
-            final JobID jobId = new JobID();
-
             // create some mock execution vertices and trigger some checkpoint
+            JobVertexID jobVertexID1 = new JobVertexID();
+            JobVertexID jobVertexID2 = new JobVertexID();
+
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-            final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID();
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .addJobVertex(jobVertexID2, false)
+                            .setTaskManagerGateway(gateway)
+                            .build();
+
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
 
-            ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
-            ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1);
-            ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2);
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
 
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
-                            .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {ackVertex1, ackVertex2})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
+                            .setExecutionGraph(graph)
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -1454,17 +1583,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             // wrong job id
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(new JobID(), ackAttemptID1, checkpointId),
+                    new AcknowledgeCheckpoint(new JobID(), attemptID1, checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
 
             // unknown checkpoint
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID1, 1L),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, 1L),
                     TASK_MANAGER_LOCATION_INFO);
 
             // unknown ack vertex
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId),
+                    new AcknowledgeCheckpoint(
+                            graph.getJobID(), new ExecutionAttemptID(), checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
 
             checkpointCoordinator.shutdown();
@@ -1484,29 +1614,33 @@ public class CheckpointCoordinatorTest extends TestLogger {
      */
     @Test
     public void testStateCleanupForLateOrUnknownMessages() throws Exception {
-        final JobID jobId = new JobID();
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
 
-        final ExecutionAttemptID triggerAttemptId = new ExecutionAttemptID();
-        final ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptId);
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-        final ExecutionAttemptID ackAttemptId1 = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptId1);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2, false)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
 
-        final ExecutionAttemptID ackAttemptId2 = new ExecutionAttemptID();
-        final ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptId2);
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
 
         CheckpointCoordinatorConfiguration chkConfig =
-                new CheckpointCoordinatorConfigurationBuilder()
+                new CheckpointCoordinatorConfiguration.CheckpointCoordinatorConfigurationBuilder()
                         .setMaxConcurrentCheckpoints(1)
                         .build();
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(chkConfig)
-                        .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                        .setTasksToWaitFor(
-                                new ExecutionVertex[] {triggerVertex, ackVertex1, ackVertex2})
-                        .setTasksToCommitTo(new ExecutionVertex[0])
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
 
@@ -1522,7 +1656,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         long checkpointId = pendingCheckpoint.getCheckpointId();
 
-        OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId());
+        OperatorID opIDtrigger =
+                vertex1.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
 
         TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new TaskStateSnapshot());
         OperatorSubtaskState subtaskStateTrigger = mock(OperatorSubtaskState.class);
@@ -1532,8 +1667,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // acknowledge the first trigger vertex
         checkpointCoordinator.receiveAcknowledgeMessage(
                 new AcknowledgeCheckpoint(
-                        jobId,
-                        triggerAttemptId,
+                        graph.getJobID(),
+                        attemptID1,
                         checkpointId,
                         new CheckpointMetrics(),
                         taskOperatorSubtaskStatesTrigger),
@@ -1547,7 +1682,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // receive an acknowledge message for an unknown vertex
         checkpointCoordinator.receiveAcknowledgeMessage(
                 new AcknowledgeCheckpoint(
-                        jobId,
+                        graph.getJobID(),
                         new ExecutionAttemptID(),
                         checkpointId,
                         new CheckpointMetrics(),
@@ -1576,8 +1711,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
         TaskStateSnapshot triggerSubtaskState = mock(TaskStateSnapshot.class);
         checkpointCoordinator.receiveAcknowledgeMessage(
                 new AcknowledgeCheckpoint(
-                        jobId,
-                        triggerAttemptId,
+                        graph.getJobID(),
+                        attemptID1,
                         checkpointId,
                         new CheckpointMetrics(),
                         triggerSubtaskState),
@@ -1590,8 +1725,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
         reset(subtaskStateTrigger);
         checkpointCoordinator.receiveDeclineMessage(
                 new DeclineCheckpoint(
-                        jobId,
-                        ackAttemptId1,
+                        graph.getJobID(),
+                        attemptID1,
                         checkpointId,
                         new CheckpointException(CHECKPOINT_DECLINED)),
                 TASK_MANAGER_LOCATION_INFO);
@@ -1606,8 +1741,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // late acknowledge message from the second ack vertex
         checkpointCoordinator.receiveAcknowledgeMessage(
                 new AcknowledgeCheckpoint(
-                        jobId,
-                        ackAttemptId2,
+                        graph.getJobID(),
+                        attemptID2,
                         checkpointId,
                         new CheckpointMetrics(),
                         ackSubtaskState),
@@ -1635,7 +1770,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // receive an acknowledge message for an unknown vertex
         checkpointCoordinator.receiveAcknowledgeMessage(
                 new AcknowledgeCheckpoint(
-                        jobId,
+                        graph.getJobID(),
                         new ExecutionAttemptID(),
                         checkpointId,
                         new CheckpointMetrics(),
@@ -1663,17 +1798,27 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
     @Test
     public void testTriggerAndConfirmSimpleSavepoint() throws Exception {
-        final JobID jobId = new JobID();
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
 
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-        ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-        ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
 
         // set up the coordinator and validate the initial state
-        CheckpointCoordinator checkpointCoordinator =
-                getCheckpointCoordinator(jobId, vertex1, vertex2);
+        CheckpointCoordinator checkpointCoordinator = getCheckpointCoordinator(graph);
 
         assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
         assertEquals(0, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1694,7 +1839,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         assertNotNull(pending);
         assertEquals(checkpointId, pending.getCheckpointId());
-        assertEquals(jobId, pending.getJobId());
+        assertEquals(graph.getJobID(), pending.getJobId());
         assertEquals(2, pending.getNumberOfNonAcknowledgedTasks());
         assertEquals(0, pending.getNumberOfAcknowledgedTasks());
         assertEquals(0, pending.getOperatorStates().size());
@@ -1716,7 +1861,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // acknowledge from one of the tasks
         AcknowledgeCheckpoint acknowledgeCheckpoint2 =
                 new AcknowledgeCheckpoint(
-                        jobId,
+                        graph.getJobID(),
                         attemptID2,
                         checkpointId,
                         new CheckpointMetrics(),
@@ -1739,7 +1884,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // acknowledge the other task.
         checkpointCoordinator.receiveAcknowledgeMessage(
                 new AcknowledgeCheckpoint(
-                        jobId,
+                        graph.getJobID(),
                         attemptID1,
                         checkpointId,
                         new CheckpointMetrics(),
@@ -1756,11 +1901,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
         assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
 
         // validate that the relevant tasks got a confirmation message
-        {
-            verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointId), any(Long.class));
-            verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointId), any(Long.class));
+        for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+            ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+            assertEquals(
+                    checkpointId,
+                    gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
         }
 
         // validate that the shared states are registered
@@ -1770,13 +1913,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
         }
 
         CompletedCheckpoint success = checkpointCoordinator.getSuccessfulCheckpoints().get(0);
-        assertEquals(jobId, success.getJobId());
+        assertEquals(graph.getJobID(), success.getJobId());
         assertEquals(pending.getCheckpointId(), success.getCheckpointID());
         assertEquals(2, success.getOperatorStates().size());
 
         // ---------------
         // trigger another checkpoint and see that this one replaces the other checkpoint
         // ---------------
+        gateway.resetCount();
         savepointFuture = checkpointCoordinator.triggerSavepoint(savepointDir);
         manuallyTriggeredScheduledExecutor.triggerAll();
         assertFalse(savepointFuture.isDone());
@@ -1784,17 +1928,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
         long checkpointIdNew =
                 checkpointCoordinator.getPendingCheckpoints().entrySet().iterator().next().getKey();
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID1, checkpointIdNew),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, checkpointIdNew),
                 TASK_MANAGER_LOCATION_INFO);
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID2, checkpointIdNew),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointIdNew),
                 TASK_MANAGER_LOCATION_INFO);
 
         assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
         assertEquals(1, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
 
         CompletedCheckpoint successNew = checkpointCoordinator.getSuccessfulCheckpoints().get(0);
-        assertEquals(jobId, successNew.getJobId());
+        assertEquals(graph.getJobID(), successNew.getJobId());
         assertEquals(checkpointIdNew, successNew.getCheckpointID());
         assertTrue(successNew.getOperatorStates().isEmpty());
         assertNotNull(savepointFuture.get());
@@ -1804,18 +1948,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
         verify(subtaskState2, never()).discardState();
 
         // validate that the relevant tasks got a confirmation message
-        {
-            verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointIdNew), any(Long.class), any(CheckpointOptions.class));
-            verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                    .triggerCheckpoint(
-                            eq(checkpointIdNew), any(Long.class), any(CheckpointOptions.class));
-
-            verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointIdNew), any(Long.class));
-            verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                    .notifyCheckpointComplete(eq(checkpointIdNew), any(Long.class));
+        for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+            ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+            assertEquals(
+                    checkpointIdNew, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+            assertEquals(
+                    checkpointIdNew,
+                    gateway.getOnlyNotifiedCompletedCheckpoint(attemptId).checkpointId);
         }
 
         checkpointCoordinator.shutdown();
@@ -1829,25 +1968,31 @@ public class CheckpointCoordinatorTest extends TestLogger {
      */
     @Test
     public void testSavepointsAreNotSubsumed() throws Exception {
-        final JobID jobId = new JobID();
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
 
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-        ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-        ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2)
+                        .build();
+
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
 
         StandaloneCheckpointIDCounter counter = new StandaloneCheckpointIDCounter();
 
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(
                                 CheckpointCoordinatorConfiguration.builder()
                                         .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
                                         .build())
-                        .setTasks(new ExecutionVertex[] {vertex1, vertex2})
                         .setCheckpointIDCounter(counter)
                         .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(10))
                         .setTimer(manuallyTriggeredScheduledExecutor)
@@ -1878,10 +2023,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         // 2nd checkpoint should subsume the 1st checkpoint, but not the savepoint
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID1, checkpointId2),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, checkpointId2),
                 TASK_MANAGER_LOCATION_INFO);
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID2, checkpointId2),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, checkpointId2),
                 TASK_MANAGER_LOCATION_INFO);
 
         assertEquals(1, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -1905,14 +2050,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         // 2nd savepoint should subsume the last checkpoint, but not the 1st savepoint
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID1, savepointId2),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, savepointId2),
                 TASK_MANAGER_LOCATION_INFO);
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID2, savepointId2),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, savepointId2),
                 TASK_MANAGER_LOCATION_INFO);
 
         assertEquals(1, checkpointCoordinator.getNumberOfPendingCheckpoints());
         assertEquals(2, checkpointCoordinator.getNumberOfRetainedSuccessfulCheckpoints());
+
         assertFalse(checkpointCoordinator.getPendingCheckpoints().get(savepointId1).isDisposed());
 
         assertFalse(savepointFuture1.isDone());
@@ -1920,10 +2066,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         // Ack first savepoint
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID1, savepointId1),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, savepointId1),
                 TASK_MANAGER_LOCATION_INFO);
         checkpointCoordinator.receiveAcknowledgeMessage(
-                new AcknowledgeCheckpoint(jobId, attemptID2, savepointId1),
+                new AcknowledgeCheckpoint(graph.getJobID(), attemptID2, savepointId1),
                 TASK_MANAGER_LOCATION_INFO);
 
         assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
@@ -1933,39 +2079,24 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
     private void testMaxConcurrentAttempts(int maxConcurrentAttempts) {
         try {
-            final JobID jobId = new JobID();
+            JobVertexID jobVertexID1 = new JobVertexID();
 
-            // create some mock execution vertices and trigger some checkpoint
-            final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
-
-            ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
-            ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID);
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
-
-            final AtomicInteger numCalls = new AtomicInteger();
-
-            final Execution execution = triggerVertex.getCurrentExecutionAttempt();
-
-            doAnswer(
-                            invocation -> {
-                                numCalls.incrementAndGet();
-                                return null;
-                            })
-                    .when(execution)
-                    .triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
-
-            doAnswer(
-                            invocation -> {
-                                numCalls.incrementAndGet();
-                                return null;
-                            })
-                    .when(execution)
-                    .notifyCheckpointComplete(anyLong(), anyLong());
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .setTaskManagerGateway(gateway)
+                            .build();
+
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
 
             CheckpointCoordinatorConfiguration chkConfig =
-                    new CheckpointCoordinatorConfigurationBuilder()
+                    new CheckpointCoordinatorConfiguration
+                                    .CheckpointCoordinatorConfigurationBuilder()
                             .setCheckpointInterval(10) // periodic interval is 10 ms
                             .setCheckpointTimeout(200000) // timeout is very long (200 s)
                             .setMinPauseBetweenCheckpoints(0L) // no extra delay
@@ -1973,11 +2104,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
                             .build();
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(chkConfig)
-                            .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {ackVertex})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -1989,29 +2117,30 @@ public class CheckpointCoordinatorTest extends TestLogger {
                 manuallyTriggeredScheduledExecutor.triggerAll();
             }
 
-            assertEquals(maxConcurrentAttempts, numCalls.get());
-
-            verify(triggerVertex.getCurrentExecutionAttempt(), times(maxConcurrentAttempts))
-                    .triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
+            assertEquals(maxConcurrentAttempts, gateway.getTriggeredCheckpoints(attemptID1).size());
+            assertEquals(0, gateway.getNotifiedCompletedCheckpoints(attemptID1).size());
 
             // now, once we acknowledge one checkpoint, it should trigger the next one
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID, 1L), TASK_MANAGER_LOCATION_INFO);
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, 1L),
+                    TASK_MANAGER_LOCATION_INFO);
 
             final Collection<ScheduledFuture<?>> periodicScheduledTasks =
                     manuallyTriggeredScheduledExecutor.getPeriodicScheduledTask();
             assertEquals(1, periodicScheduledTasks.size());
-            final ScheduledFuture scheduledFuture = periodicScheduledTasks.iterator().next();
 
             manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
             manuallyTriggeredScheduledExecutor.triggerAll();
 
-            assertEquals(maxConcurrentAttempts + 1, numCalls.get());
+            assertEquals(
+                    maxConcurrentAttempts + 1, gateway.getTriggeredCheckpoints(attemptID1).size());
 
             // no further checkpoints should happen
             manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
             manuallyTriggeredScheduledExecutor.triggerAll();
-            assertEquals(maxConcurrentAttempts + 1, numCalls.get());
+
+            assertEquals(
+                    maxConcurrentAttempts + 1, gateway.getTriggeredCheckpoints(attemptID1).size());
 
             checkpointCoordinator.shutdown();
         } catch (Exception e) {
@@ -2024,19 +2153,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
     public void testMaxConcurrentAttempsWithSubsumption() {
         try {
             final int maxConcurrentAttempts = 2;
-            final JobID jobId = new JobID();
+            JobVertexID jobVertexID1 = new JobVertexID();
 
-            // create some mock execution vertices and trigger some checkpoint
-            final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .build();
 
-            ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
-            ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID);
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+
+            ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
 
             CheckpointCoordinatorConfiguration chkConfig =
-                    new CheckpointCoordinatorConfigurationBuilder()
+                    new CheckpointCoordinatorConfiguration
+                                    .CheckpointCoordinatorConfigurationBuilder()
                             .setCheckpointInterval(10) // periodic interval is 10 ms
                             .setCheckpointTimeout(200000) // timeout is very long (200 s)
                             .setMinPauseBetweenCheckpoints(0L) // no extra delay
@@ -2044,11 +2174,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
                             .build();
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(chkConfig)
-                            .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {ackVertex})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -2070,7 +2197,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
             // and allow two more checkpoints to be triggered
             // now, once we acknowledge one checkpoint, it should trigger the next one
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, ackAttemptID, 2L), TASK_MANAGER_LOCATION_INFO);
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, 2L),
+                    TASK_MANAGER_LOCATION_INFO);
 
             // after a while, there should be the new checkpoints
             do {
@@ -2094,24 +2222,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testPeriodicSchedulingWithInactiveTasks() {
         try {
-            final JobID jobId = new JobID();
-
-            // create some mock execution vertices and trigger some checkpoint
-            final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
+            JobVertexID jobVertexID1 = new JobVertexID();
 
-            ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
-            ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID);
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID1)
+                            .setTransitToRunning(false)
+                            .build();
 
-            final AtomicReference<ExecutionState> currentState =
-                    new AtomicReference<>(ExecutionState.CREATED);
-            when(triggerVertex.getCurrentExecutionAttempt().getState())
-                    .thenAnswer(invocation -> currentState.get());
+            ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
 
             CheckpointCoordinatorConfiguration chkConfig =
-                    new CheckpointCoordinatorConfigurationBuilder()
+                    new CheckpointCoordinatorConfiguration
+                                    .CheckpointCoordinatorConfigurationBuilder()
                             .setCheckpointInterval(10) // periodic interval is 10 ms
                             .setCheckpointTimeout(200000) // timeout is very long (200 s)
                             .setMinPauseBetweenCheckpoints(0) // no extra delay
@@ -2119,11 +2242,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
                             .build();
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jobId)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(chkConfig)
-                            .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {ackVertex})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -2136,7 +2256,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
             assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
 
             // now move the state to RUNNING
-            currentState.set(ExecutionState.RUNNING);
+            vertex1.getCurrentExecutionAttempt().transitionState(ExecutionState.RUNNING);
 
             // the coordinator should start checkpointing now
             manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
@@ -2152,24 +2272,30 @@ public class CheckpointCoordinatorTest extends TestLogger {
     /** Tests that the savepoints can be triggered concurrently. */
     @Test
     public void testConcurrentSavepoints() throws Exception {
-        JobID jobId = new JobID();
         int numSavepoints = 5;
 
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
+        JobVertexID jobVertexID1 = new JobVertexID();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .build();
+
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
 
         StandaloneCheckpointIDCounter checkpointIDCounter = new StandaloneCheckpointIDCounter();
 
         CheckpointCoordinatorConfiguration chkConfig =
-                new CheckpointCoordinatorConfigurationBuilder()
+                new CheckpointCoordinatorConfiguration.CheckpointCoordinatorConfigurationBuilder()
                         .setMaxConcurrentCheckpoints(
                                 1) // max one checkpoint at a time => should not affect savepoints
                         .build();
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(chkConfig)
-                        .setTasks(new ExecutionVertex[] {vertex1})
                         .setCheckpointIDCounter(checkpointIDCounter)
                         .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                         .setTimer(manuallyTriggeredScheduledExecutor)
@@ -2195,7 +2321,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         long checkpointId = checkpointIDCounter.getLast();
         for (int i = 0; i < numSavepoints; i++, checkpointId--) {
             checkpointCoordinator.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jobId, attemptID1, checkpointId),
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID1, checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
         }
 
@@ -2209,7 +2335,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
     @Test
     public void testMinDelayBetweenSavepoints() throws Exception {
         CheckpointCoordinatorConfiguration chkConfig =
-                new CheckpointCoordinatorConfigurationBuilder()
+                new CheckpointCoordinatorConfiguration.CheckpointCoordinatorConfigurationBuilder()
                         .setMinPauseBetweenCheckpoints(
                                 100000000L) // very long min delay => should not affect savepoints
                         .setMaxConcurrentCheckpoints(1)
@@ -2237,14 +2363,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
     public void testExternalizedCheckpoints() throws Exception {
         try {
 
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(new JobVertexID())
+                            .build();
+
             // set up the coordinator and validate the initial state
             CheckpointCoordinatorConfiguration chkConfig =
-                    new CheckpointCoordinatorConfigurationBuilder()
+                    new CheckpointCoordinatorConfiguration
+                                    .CheckpointCoordinatorConfigurationBuilder()
                             .setCheckpointRetentionPolicy(
                                     CheckpointRetentionPolicy.RETAIN_ON_FAILURE)
                             .build();
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(chkConfig)
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
@@ -2483,7 +2616,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
     /** Tests that the pending checkpoint stats callbacks are created. */
     @Test
     public void testCheckpointStatsTrackerPendingCheckpointCallback() throws Exception {
-
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
@@ -2551,23 +2683,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
     @Test
     public void testSharedStateRegistrationOnRestore() throws Exception {
-
-        final JobID jobId = new JobID();
-
-        final JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID1 = new JobVertexID();
 
         int parallelism1 = 2;
         int maxParallelism1 = 4;
 
-        final ExecutionJobVertex jobVertex1 =
-                mockExecutionJobVertex(jobVertexID1, parallelism1, maxParallelism1);
-
-        List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1);
-
-        allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1, parallelism1, maxParallelism1)
+                        .build();
 
-        ExecutionVertex[] arrayExecutionVertices =
-                allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+        ExecutionJobVertex jobVertex1 = graph.getJobVertex(jobVertexID1);
 
         EmbeddedCompletedCheckpointStore store = new EmbeddedCompletedCheckpointStore(10);
 
@@ -2576,8 +2702,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
-                        .setTasks(arrayExecutionVertices)
+                        .setExecutionGraph(graph)
                         .setCompletedCheckpointStore(store)
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .setSharedStateRegistryFactory(
@@ -2596,7 +2721,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         for (int i = 0; i < numCheckpoints; ++i) {
             performIncrementalCheckpoint(
-                    jobId, checkpointCoordinator, jobVertex1, keyGroupPartitions1, i);
+                    graph.getJobID(), checkpointCoordinator, jobVertex1, keyGroupPartitions1, i);
         }
 
         List<CompletedCheckpoint> completedCheckpoints =
@@ -2729,19 +2854,25 @@ public class CheckpointCoordinatorTest extends TestLogger {
         final Tuple2<Integer, Throwable> invocationCounterAndException = Tuple2.of(0, null);
         final Throwable expectedRootCause = new IOException("Custom-Exception");
 
-        final JobID jobId = new JobID();
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2)
+                        .build();
+
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
 
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-        final ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-        final ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
 
         // set up the coordinator and validate the initial state
         final CheckpointCoordinator coordinator =
                 getCheckpointCoordinator(
-                        jobId,
-                        vertex1,
-                        vertex2,
+                        graph,
                         new CheckpointFailureManager(
                                 0,
                                 new CheckpointFailureManager.FailJobCallback() {
@@ -2764,7 +2895,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         manuallyTriggeredScheduledExecutor.triggerAll();
         final PendingCheckpoint syncSavepoint =
-                declineSynchronousSavepoint(jobId, coordinator, attemptID1, expectedRootCause);
+                declineSynchronousSavepoint(
+                        graph.getJobID(), coordinator, attemptID1, expectedRootCause);
 
         assertTrue(syncSavepoint.isDisposed());
 
@@ -2833,7 +2965,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
         int maxConcurrentCheckpoints = 1;
         int checkpointRequestsToSend = 10;
         int activeRequests = 0;
-        JobID jobId = new JobID();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .build();
         CheckpointCoordinator coordinator =
                 new CheckpointCoordinatorBuilder()
                         .setCheckpointCoordinatorConfiguration(
@@ -2841,7 +2977,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
                                         .setUnalignedCheckpointsEnabled(true)
                                         .setMaxConcurrentCheckpoints(maxConcurrentCheckpoints)
                                         .build())
-                        .setJobId(jobId)
+                        .setExecutionGraph(graph)
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
         try {
@@ -2862,7 +2998,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
             coordinator.receiveDeclineMessage(
                     new DeclineCheckpoint(
-                            jobId,
+                            graph.getJobID(),
                             new ExecutionAttemptID(),
                             1L,
                             new CheckpointException(CHECKPOINT_DECLINED)),
@@ -2892,30 +3028,27 @@ public class CheckpointCoordinatorTest extends TestLogger {
      */
     @Test
     public void testExternallyInducedSourceWithOperatorCoordinator() throws Exception {
-        final JobID jobId = new JobID();
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
 
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-        ExecutionVertex vertex1 =
-                mockExecutionVertex(
-                        attemptID1,
-                        (executionAttemptID,
-                                jid,
-                                checkpointId,
-                                timestamp,
-                                checkpointOptions) -> {});
-        ExecutionVertex vertex2 =
-                mockExecutionVertex(
-                        attemptID2,
-                        (executionAttemptID,
-                                jid,
-                                checkpointId,
-                                timestamp,
-                                checkpointOptions) -> {});
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-        OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
-        OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
+
+        OperatorID opID1 = vertex1.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
+        OperatorID opID2 = vertex2.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
         TaskStateSnapshot taskOperatorSubtaskStates1 = new TaskStateSnapshot();
         TaskStateSnapshot taskOperatorSubtaskStates2 = new TaskStateSnapshot();
         OperatorSubtaskState subtaskState1 = OperatorSubtaskState.builder().build();
@@ -2940,8 +3073,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
-                        .setTasks(new ExecutionVertex[] {vertex1, vertex2})
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(
                                 CheckpointCoordinatorConfiguration.builder()
                                         .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
@@ -2975,14 +3107,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
                         checkpointIdRef.set(checkpointId);
                         AcknowledgeCheckpoint acknowledgeCheckpoint1 =
                                 new AcknowledgeCheckpoint(
-                                        jobId,
+                                        graph.getJobID(),
                                         attemptID1,
                                         checkpointId,
                                         new CheckpointMetrics(),
                                         taskOperatorSubtaskStates1);
                         AcknowledgeCheckpoint acknowledgeCheckpoint2 =
                                 new AcknowledgeCheckpoint(
-                                        jobId,
+                                        graph.getJobID(),
                                         attemptID2,
                                         checkpointId,
                                         new CheckpointMetrics(),
@@ -3040,13 +3172,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
         // validate that the relevant tasks got a confirmation message
         long checkpointId = checkpointIdRef.get();
-        verify(vertex1.getCurrentExecutionAttempt(), times(1))
-                .triggerCheckpoint(eq(checkpointId), any(Long.class), any(CheckpointOptions.class));
-        verify(vertex2.getCurrentExecutionAttempt(), times(1))
-                .triggerCheckpoint(eq(checkpointId), any(Long.class), any(CheckpointOptions.class));
+        for (ExecutionVertex vertex : Arrays.asList(vertex1, vertex2)) {
+            ExecutionAttemptID attemptId = vertex.getCurrentExecutionAttempt().getAttemptId();
+            assertEquals(checkpointId, gateway.getOnlyTriggeredCheckpoint(attemptId).checkpointId);
+        }
 
         CompletedCheckpoint success = checkpointCoordinator.getSuccessfulCheckpoints().get(0);
-        assertEquals(jobId, success.getJobId());
+        assertEquals(graph.getJobID(), success.getJobId());
         assertEquals(2, success.getOperatorStates().size());
 
         checkpointCoordinator.shutdown();
@@ -3054,30 +3186,23 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
     @Test
     public void testCompleteCheckpointFailureWithExternallyInducedSource() throws Exception {
-        final JobID jobId = new JobID();
+        JobVertexID jobVertexID1 = new JobVertexID();
+        JobVertexID jobVertexID2 = new JobVertexID();
 
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
-        ExecutionVertex vertex1 =
-                mockExecutionVertex(
-                        attemptID1,
-                        (executionAttemptID,
-                                jid,
-                                checkpointId,
-                                timestamp,
-                                checkpointOptions) -> {});
-        ExecutionVertex vertex2 =
-                mockExecutionVertex(
-                        attemptID2,
-                        (executionAttemptID,
-                                jid,
-                                checkpointId,
-                                timestamp,
-                                checkpointOptions) -> {});
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID1)
+                        .addJobVertex(jobVertexID2)
+                        .build();
 
-        OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
-        OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
+        ExecutionVertex vertex1 = graph.getJobVertex(jobVertexID1).getTaskVertices()[0];
+        ExecutionVertex vertex2 = graph.getJobVertex(jobVertexID2).getTaskVertices()[0];
+
+        ExecutionAttemptID attemptID1 = vertex1.getCurrentExecutionAttempt().getAttemptId();
+        ExecutionAttemptID attemptID2 = vertex2.getCurrentExecutionAttempt().getAttemptId();
+
+        OperatorID opID1 = vertex1.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
+        OperatorID opID2 = vertex2.getJobVertex().getOperatorIDs().get(0).getGeneratedOperatorID();
         TaskStateSnapshot taskOperatorSubtaskStates1 = new TaskStateSnapshot();
         TaskStateSnapshot taskOperatorSubtaskStates2 = new TaskStateSnapshot();
         OperatorSubtaskState subtaskState1 = OperatorSubtaskState.builder().build();
@@ -3102,8 +3227,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
-                        .setTasks(new ExecutionVertex[] {vertex1, vertex2})
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(
                                 CheckpointCoordinatorConfiguration.builder()
                                         .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
@@ -3166,14 +3290,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
                         checkpointIdRef.set(checkpointId);
                         AcknowledgeCheckpoint acknowledgeCheckpoint1 =
                                 new AcknowledgeCheckpoint(
-                                        jobId,
+                                        graph.getJobID(),
                                         attemptID1,
                                         checkpointId,
                                         new CheckpointMetrics(),
                                         taskOperatorSubtaskStates1);
                         AcknowledgeCheckpoint acknowledgeCheckpoint2 =
                                 new AcknowledgeCheckpoint(
-                                        jobId,
+                                        graph.getJobID(),
                                         attemptID2,
                                         checkpointId,
                                         new CheckpointMetrics(),
@@ -3222,9 +3346,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
     @Test
     public void testNotifyCheckpointAbortionInOperatorCoordinator() throws Exception {
-        JobID jobId = new JobID();
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        ExecutionVertex executionVertex = mockExecutionVertex(attemptID);
+        JobVertexID jobVertexID = new JobVertexID();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .build();
+
+        ExecutionVertex executionVertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+        ExecutionAttemptID attemptID = executionVertex.getCurrentExecutionAttempt().getAttemptId();
 
         CheckpointCoordinatorTestingUtils.MockOperatorCoordinatorCheckpointContext context =
                 new CheckpointCoordinatorTestingUtils
@@ -3237,8 +3366,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jobId)
-                        .setTasks(new ExecutionVertex[] {executionVertex})
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(
                                 CheckpointCoordinatorConfiguration.builder()
                                         .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
@@ -3261,7 +3389,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
                     Collections.max(checkpointCoordinator.getPendingCheckpoints().keySet());
             AcknowledgeCheckpoint acknowledgeCheckpoint1 =
                     new AcknowledgeCheckpoint(
-                            jobId, attemptID, checkpointId2, new CheckpointMetrics(), null);
+                            graph.getJobID(),
+                            attemptID,
+                            checkpointId2,
+                            new CheckpointMetrics(),
+                            null);
             checkpointCoordinator.receiveAcknowledgeMessage(acknowledgeCheckpoint1, "");
 
             // OperatorCoordinator should have been notified of the abortion of checkpoint 1.
@@ -3272,12 +3404,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
         }
     }
 
-    private CheckpointCoordinator getCheckpointCoordinator(
-            JobID jobId, ExecutionVertex vertex1, ExecutionVertex vertex2) {
-
+    private CheckpointCoordinator getCheckpointCoordinator(ExecutionGraph graph) throws Exception {
         return new CheckpointCoordinatorBuilder()
-                .setJobId(jobId)
-                .setTasks(new ExecutionVertex[] {vertex1, vertex2})
+                .setExecutionGraph(graph)
                 .setCheckpointCoordinatorConfiguration(
                         CheckpointCoordinatorConfiguration.builder()
                                 .setAlignmentTimeout(Long.MAX_VALUE)
@@ -3288,51 +3417,25 @@ public class CheckpointCoordinatorTest extends TestLogger {
     }
 
     private CheckpointCoordinator getCheckpointCoordinator(
-            JobID jobId,
-            ExecutionVertex vertex1,
-            ExecutionVertex vertex2,
-            CheckpointFailureManager failureManager) {
+            ExecutionGraph graph, CheckpointFailureManager failureManager) throws Exception {
 
         return new CheckpointCoordinatorBuilder()
-                .setJobId(jobId)
-                .setTasks(new ExecutionVertex[] {vertex1, vertex2})
+                .setExecutionGraph(graph)
                 .setTimer(manuallyTriggeredScheduledExecutor)
                 .setFailureManager(failureManager)
                 .build();
     }
 
-    private CheckpointCoordinator getCheckpointCoordinator(ExecutionState triggerVertexState) {
-        return getCheckpointCoordinator(manuallyTriggeredScheduledExecutor, triggerVertexState);
-    }
-
-    private CheckpointCoordinator getCheckpointCoordinator(
-            ScheduledExecutor timer, ExecutionState triggerVertexState) {
-        final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID();
-        ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1);
-        JobVertexID jobVertexID2 = new JobVertexID();
-        ExecutionVertex triggerVertex2 =
-                mockExecutionVertex(
-                        triggerAttemptID2,
-                        jobVertexID2,
-                        Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID2)),
-                        1,
-                        1,
-                        triggerVertexState);
-
-        // create some mock Execution vertices that need to ack the checkpoint
-        final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID();
-        final ExecutionAttemptID ackAttemptID2 = new ExecutionAttemptID();
-        ExecutionVertex ackVertex1 = mockExecutionVertex(ackAttemptID1);
-        ExecutionVertex ackVertex2 = mockExecutionVertex(ackAttemptID2);
+    private CheckpointCoordinator getCheckpointCoordinator(ScheduledExecutor timer)
+            throws Exception {
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .addJobVertex(new JobVertexID())
+                        .build();
 
         // set up the coordinator and validate the initial state
-        return new CheckpointCoordinatorBuilder()
-                .setTasksToTrigger(new ExecutionVertex[] {triggerVertex1, triggerVertex2})
-                .setTasksToWaitFor(new ExecutionVertex[] {ackVertex1, ackVertex2})
-                .setTasksToCommitTo(new ExecutionVertex[] {})
-                .setTimer(timer)
-                .build();
+        return new CheckpointCoordinatorBuilder().setExecutionGraph(graph).setTimer(timer).build();
     }
 
     private CheckpointFailureManager getCheckpointFailureManager(String errorMsg) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
index 348f3ce..479717d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -19,24 +19,25 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
-import org.apache.flink.mock.Whitebox;
 import org.apache.flink.runtime.OperatorIDPair;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphCheckpointPlanCalculatorContext;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
-import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.IntermediateResult;
 import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway;
-import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway.CheckpointConsumer;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
@@ -56,12 +57,12 @@ import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SharedStateRegistryFactory;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 
 import org.junit.Assert;
-import org.mockito.invocation.InvocationOnMock;
 
 import javax.annotation.Nullable;
 
@@ -69,7 +70,6 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.nio.charset.StandardCharsets;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
@@ -84,12 +84,7 @@ import java.util.function.Consumer;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyLong;
-import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.when;
 
 /** Testing utils for checkpoint coordinator. */
 public class CheckpointCoordinatorTestingUtils {
@@ -356,154 +351,6 @@ public class CheckpointCoordinatorTestingUtils {
         }
     }
 
-    static ExecutionJobVertex mockExecutionJobVertex(
-            JobVertexID jobVertexID, int parallelism, int maxParallelism) throws Exception {
-
-        return mockExecutionJobVertex(
-                jobVertexID,
-                Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
-                parallelism,
-                maxParallelism);
-    }
-
-    static ExecutionJobVertex mockExecutionJobVertex(
-            JobVertexID jobVertexID,
-            List<OperatorID> jobVertexIDs,
-            int parallelism,
-            int maxParallelism)
-            throws Exception {
-        final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
-
-        ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
-
-        for (int i = 0; i < parallelism; i++) {
-            executionVertices[i] =
-                    mockExecutionVertex(
-                            new ExecutionAttemptID(),
-                            jobVertexID,
-                            jobVertexIDs,
-                            parallelism,
-                            maxParallelism,
-                            ExecutionState.RUNNING);
-
-            when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
-            when(executionVertices[i].getJobVertex()).thenReturn(executionJobVertex);
-        }
-
-        when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
-        when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
-        when(executionJobVertex.getParallelism()).thenReturn(parallelism);
-        when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
-        when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
-        List<OperatorIDPair> operatorIDPairs = new ArrayList<>();
-        for (OperatorID operatorID : jobVertexIDs) {
-            operatorIDPairs.add(OperatorIDPair.generatedIDOnly(operatorID));
-        }
-        when(executionJobVertex.getOperatorIDs()).thenReturn(operatorIDPairs);
-        when(executionJobVertex.getProducedDataSets()).thenReturn(new IntermediateResult[0]);
-
-        return executionJobVertex;
-    }
-
-    static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
-        return mockExecutionVertex(attemptID, (LogicalSlot) null);
-    }
-
-    static ExecutionVertex mockExecutionVertex(
-            ExecutionAttemptID attemptID, CheckpointConsumer checkpointConsumer) {
-
-        final SimpleAckingTaskManagerGateway taskManagerGateway =
-                new SimpleAckingTaskManagerGateway();
-        taskManagerGateway.setCheckpointConsumer(checkpointConsumer);
-        return mockExecutionVertex(attemptID, taskManagerGateway);
-    }
-
-    static ExecutionVertex mockExecutionVertex(
-            ExecutionAttemptID attemptID, TaskManagerGateway taskManagerGateway) {
-
-        TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
-        slotBuilder.setTaskManagerGateway(taskManagerGateway);
-        LogicalSlot slot = slotBuilder.createTestingLogicalSlot();
-        return mockExecutionVertex(attemptID, slot);
-    }
-
-    static ExecutionVertex mockExecutionVertex(
-            ExecutionAttemptID attemptID, @Nullable LogicalSlot slot) {
-
-        JobVertexID jobVertexID = new JobVertexID();
-        return mockExecutionVertex(
-                attemptID,
-                jobVertexID,
-                Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
-                slot,
-                1,
-                1,
-                ExecutionState.RUNNING);
-    }
-
-    static ExecutionVertex mockExecutionVertex(
-            ExecutionAttemptID attemptID,
-            JobVertexID jobVertexID,
-            List<OperatorID> jobVertexIDs,
-            int parallelism,
-            int maxParallelism,
-            ExecutionState state,
-            ExecutionState... successiveStates) {
-
-        return mockExecutionVertex(
-                attemptID,
-                jobVertexID,
-                jobVertexIDs,
-                null,
-                parallelism,
-                maxParallelism,
-                state,
-                successiveStates);
-    }
-
-    static ExecutionVertex mockExecutionVertex(
-            ExecutionAttemptID attemptID,
-            JobVertexID jobVertexID,
-            List<OperatorID> jobVertexIDs,
-            @Nullable LogicalSlot slot,
-            int parallelism,
-            int maxParallelism,
-            ExecutionState state,
-            ExecutionState... successiveStates) {
-
-        ExecutionVertex vertex = mock(ExecutionVertex.class);
-        when(vertex.getID()).thenReturn(ExecutionGraphTestUtils.createRandomExecutionVertexId());
-        when(vertex.getJobId()).thenReturn(new JobID());
-
-        final Execution exec =
-                spy(new Execution(mock(Executor.class), vertex, 1, 1L, Time.milliseconds(500L)));
-        if (slot != null) {
-            // is there a better way to do this?
-            Whitebox.setInternalState(exec, "assignedResource", slot);
-        }
-
-        when(exec.getAttemptId()).thenReturn(attemptID);
-        when(exec.getState()).thenReturn(state, successiveStates);
-
-        when(vertex.getJobvertexId()).thenReturn(jobVertexID);
-        when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
-        when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
-        when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
-
-        ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
-        List<OperatorIDPair> operatorIDPairs = new ArrayList<>();
-        for (OperatorID operatorID : jobVertexIDs) {
-            operatorIDPairs.add(OperatorIDPair.generatedIDOnly(operatorID));
-        }
-        when(jobVertex.getOperatorIDs()).thenReturn(operatorIDPairs);
-        when(jobVertex.getJobVertexId()).thenReturn(jobVertexID);
-        when(jobVertex.getParallelism()).thenReturn(parallelism);
-
-        when(vertex.getJobVertex()).thenReturn(jobVertex);
-
-        return vertex;
-    }
-
     static TaskStateSnapshot mockSubtaskState(
             JobVertexID jobVertexID, int index, KeyGroupRange keyGroupRange) throws IOException {
 
@@ -563,81 +410,249 @@ public class CheckpointCoordinatorTestingUtils {
         return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
     }
 
-    static Execution mockExecution() {
-        Execution mock = mock(Execution.class);
-        when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
-        when(mock.getState()).thenReturn(ExecutionState.RUNNING);
-        return mock;
+    static class TriggeredCheckpoint {
+        final JobID jobId;
+        final long checkpointId;
+        final long timestamp;
+        final CheckpointOptions checkpointOptions;
+
+        public TriggeredCheckpoint(
+                JobID jobId,
+                long checkpointId,
+                long timestamp,
+                CheckpointOptions checkpointOptions) {
+            this.jobId = jobId;
+            this.checkpointId = checkpointId;
+            this.timestamp = timestamp;
+            this.checkpointOptions = checkpointOptions;
+        }
     }
 
-    static Execution mockExecution(CheckpointConsumer checkpointConsumer) {
-        ExecutionVertex executionVertex = mock(ExecutionVertex.class);
-        final JobID jobId = new JobID();
-        when(executionVertex.getJobId()).thenReturn(jobId);
-        Execution mock = mock(Execution.class);
-        ExecutionAttemptID executionAttemptID = new ExecutionAttemptID();
-        when(mock.getAttemptId()).thenReturn(executionAttemptID);
-        when(mock.getState()).thenReturn(ExecutionState.RUNNING);
-        when(mock.getVertex()).thenReturn(executionVertex);
-        doAnswer(
-                        (InvocationOnMock invocation) -> {
-                            final Object[] args = invocation.getArguments();
-                            checkpointConsumer.accept(
-                                    executionAttemptID,
-                                    jobId,
-                                    (long) args[0],
-                                    (long) args[1],
-                                    (CheckpointOptions) args[2]);
-                            return null;
-                        })
-                .when(mock)
-                .triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
-        return mock;
+    static class NotifiedCheckpoint {
+        final JobID jobId;
+        final long checkpointId;
+        final long timestamp;
+
+        public NotifiedCheckpoint(JobID jobId, long checkpointId, long timestamp) {
+            this.jobId = jobId;
+            this.checkpointId = checkpointId;
+            this.timestamp = timestamp;
+        }
     }
 
-    static ExecutionVertex mockExecutionVertex(
-            Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
-        ExecutionVertex mock = mock(ExecutionVertex.class);
-        when(mock.getJobvertexId()).thenReturn(vertexId);
-        when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
-        when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
-        when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
-        when(mock.getMaxParallelism()).thenReturn(parallelism);
-        return mock;
+    static class CheckpointRecorderTaskManagerGateway extends SimpleAckingTaskManagerGateway {
+
+        private final Map<ExecutionAttemptID, List<TriggeredCheckpoint>> triggeredCheckpoints =
+                new HashMap<>();
+
+        private final Map<ExecutionAttemptID, List<NotifiedCheckpoint>>
+                notifiedCompletedCheckpoints = new HashMap<>();
+
+        private final Map<ExecutionAttemptID, List<NotifiedCheckpoint>> notifiedAbortCheckpoints =
+                new HashMap<>();
+
+        @Override
+        public void triggerCheckpoint(
+                ExecutionAttemptID attemptId,
+                JobID jobId,
+                long checkpointId,
+                long timestamp,
+                CheckpointOptions checkpointOptions) {
+            triggeredCheckpoints
+                    .computeIfAbsent(attemptId, k -> new ArrayList<>())
+                    .add(
+                            new TriggeredCheckpoint(
+                                    jobId, checkpointId, timestamp, checkpointOptions));
+        }
+
+        @Override
+        public void notifyCheckpointComplete(
+                ExecutionAttemptID attemptId, JobID jobId, long checkpointId, long timestamp) {
+            notifiedCompletedCheckpoints
+                    .computeIfAbsent(attemptId, k -> new ArrayList<>())
+                    .add(new NotifiedCheckpoint(jobId, checkpointId, timestamp));
+        }
+
+        @Override
+        public void notifyCheckpointAborted(
+                ExecutionAttemptID attemptId, JobID jobId, long checkpointId, long timestamp) {
+            notifiedAbortCheckpoints
+                    .computeIfAbsent(attemptId, k -> new ArrayList<>())
+                    .add(new NotifiedCheckpoint(jobId, checkpointId, timestamp));
+        }
+
+        public void resetCount() {
+            triggeredCheckpoints.clear();
+            notifiedAbortCheckpoints.clear();
+            notifiedCompletedCheckpoints.clear();
+        }
+
+        /** Returns the triggers recorded for {@code attemptId}, or an empty list if none. */
+        public List<TriggeredCheckpoint> getTriggeredCheckpoints(ExecutionAttemptID attemptId) {
+            return triggeredCheckpoints.getOrDefault(attemptId, Collections.emptyList());
+        }
+
+        /** Asserts that exactly one trigger was recorded for {@code attemptId} and returns it. */
+        public TriggeredCheckpoint getOnlyTriggeredCheckpoint(ExecutionAttemptID attemptId) {
+            List<TriggeredCheckpoint> checkpoints = getTriggeredCheckpoints(attemptId);
+            String message = "There should be exactly one checkpoint triggered for " + attemptId;
+            assertEquals(message, 1, checkpoints.size());
+            return checkpoints.get(0);
+        }
+
+        /** Returns the completions notified to {@code attemptId}, or an empty list if none. */
+        public List<NotifiedCheckpoint> getNotifiedCompletedCheckpoints(
+                ExecutionAttemptID attemptId) {
+            return notifiedCompletedCheckpoints.getOrDefault(attemptId, Collections.emptyList());
+        }
+
+        /** Asserts that exactly one completion was notified to {@code attemptId}; returns it. */
+        public NotifiedCheckpoint getOnlyNotifiedCompletedCheckpoint(ExecutionAttemptID attemptId) {
+            List<NotifiedCheckpoint> completed = getNotifiedCompletedCheckpoints(attemptId);
+            String message =
+                    "There should be exactly one checkpoint notified completed for " + attemptId;
+            assertEquals(message, 1, completed.size());
+            return completed.get(0);
+        }
+
+        /** Returns the aborts notified to {@code attemptId}, or an empty list if none. */
+        public List<NotifiedCheckpoint> getNotifiedAbortedCheckpoints(
+                ExecutionAttemptID attemptId) {
+            return notifiedAbortCheckpoints.getOrDefault(attemptId, Collections.emptyList());
+        }
+
+        /** Asserts that exactly one abort was notified to {@code attemptId}; returns it. */
+        public NotifiedCheckpoint getOnlyNotifiedAbortedCheckpoint(ExecutionAttemptID attemptId) {
+            String message =
+                    "There should be exactly one checkpoint notified aborted for " + attemptId;
+            assertEquals(message, 1, getNotifiedAbortedCheckpoints(attemptId).size());
+            return getNotifiedAbortedCheckpoints(attemptId).get(0);
+        }
     }
 
-    static ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
-        ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
-        when(vertex.getParallelism()).thenReturn(vertices.length);
-        when(vertex.getMaxParallelism()).thenReturn(vertices.length);
-        when(vertex.getJobVertexId()).thenReturn(id);
-        when(vertex.getTaskVertices()).thenReturn(vertices);
-        when(vertex.getOperatorIDs())
-                .thenReturn(
-                        Collections.singletonList(
-                                OperatorIDPair.generatedIDOnly(OperatorID.fromJobVertexID(id))));
-        when(vertex.getProducedDataSets()).thenReturn(new IntermediateResult[0]);
-
-        for (ExecutionVertex v : vertices) {
-            when(v.getJobVertex()).thenReturn(vertex);
-        }
-        return vertex;
+    /** Builds a test {@link ExecutionGraph} in which every source feeds every non-source. */
+    static class CheckpointExecutionGraphBuilder {
+        private final List<JobVertex> sourceVertices = new ArrayList<>();
+        private final List<JobVertex> nonSourceVertices = new ArrayList<>();
+        private boolean transitToRunning;
+        private TaskManagerGateway taskManagerGateway;
+        private ComponentMainThreadExecutor mainThreadExecutor;
+
+        CheckpointExecutionGraphBuilder() {
+            this.transitToRunning = true;
+            this.mainThreadExecutor = ComponentMainThreadExecutorServiceAdapter.forMainThread();
+        }
+
+        public CheckpointExecutionGraphBuilder addJobVertex(JobVertexID id) {
+            return addJobVertex(id, true);
+        }
+
+        public CheckpointExecutionGraphBuilder addJobVertex(JobVertexID id, boolean isSource) {
+            return addJobVertex(id, 1, 32768, Collections.emptyList(), isSource);
+        }
+
+        public CheckpointExecutionGraphBuilder addJobVertex(
+                JobVertexID id, int parallelism, int maxParallelism) {
+            return addJobVertex(id, parallelism, maxParallelism, Collections.emptyList(), true);
+        }
+
+        public CheckpointExecutionGraphBuilder addJobVertex(
+                JobVertexID id,
+                int parallelism,
+                int maxParallelism,
+                List<OperatorIDPair> operators,
+                boolean isSource) {
+            // Only use the operator-list constructor when operator ids were supplied.
+            JobVertex jobVertex =
+                    operators.isEmpty()
+                            ? new JobVertex("anon", id)
+                            : new JobVertex("anon", id, operators);
+            jobVertex.setParallelism(parallelism);
+            jobVertex.setMaxParallelism(maxParallelism);
+            jobVertex.setInvokableClass(NoOpInvokable.class);
+
+            return addJobVertex(jobVertex, isSource);
+        }
+
+        public CheckpointExecutionGraphBuilder addJobVertex(JobVertex jobVertex, boolean isSource) {
+            if (isSource) {
+                sourceVertices.add(jobVertex);
+            } else {
+                nonSourceVertices.add(jobVertex);
+            }
+            return this;
+        }
+
+        public CheckpointExecutionGraphBuilder setTaskManagerGateway(
+                TaskManagerGateway taskManagerGateway) {
+            this.taskManagerGateway = taskManagerGateway;
+            return this;
+        }
+
+        public CheckpointExecutionGraphBuilder setTransitToRunning(boolean transitToRunning) {
+            this.transitToRunning = transitToRunning;
+            return this;
+        }
+
+        public CheckpointExecutionGraphBuilder setMainThreadExecutor(
+                ComponentMainThreadExecutor mainThreadExecutor) {
+            this.mainThreadExecutor = mainThreadExecutor;
+            return this;
+        }
+
+        /** Creates and starts the graph; every call re-connects sources to non-sources. */
+        ExecutionGraph build() throws Exception {
+            // Connect every source to every non-source vertex (ALL_TO_ALL, PIPELINED).
+            for (JobVertex source : sourceVertices) {
+                for (JobVertex nonSource : nonSourceVertices) {
+                    nonSource.connectNewDataSetAsInput(
+                            source, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+                }
+            }
+
+            List<JobVertex> allVertices = new ArrayList<>(sourceVertices);
+            allVertices.addAll(nonSourceVertices);
+
+            ExecutionGraph executionGraph =
+                    ExecutionGraphTestUtils.createSimpleTestGraph(
+                            allVertices.toArray(new JobVertex[0]));
+            executionGraph.start(mainThreadExecutor);
+            // Give every execution vertex a testing slot backed by the configured gateway.
+            if (taskManagerGateway != null) {
+                executionGraph
+                        .getAllExecutionVertices()
+                        .forEach(
+                                task -> {
+                                    LogicalSlot slot =
+                                            new TestingLogicalSlotBuilder()
+                                                    .setTaskManagerGateway(taskManagerGateway)
+                                                    .createTestingLogicalSlot();
+                                    task.tryAssignResource(slot);
+                                });
+            }
+            // Move the graph and each current execution attempt to RUNNING.
+            if (transitToRunning) {
+                executionGraph.transitionToRunning();
+                executionGraph
+                        .getAllExecutionVertices()
+                        .forEach(
+                                task ->
+                                        task.getCurrentExecutionAttempt()
+                                                .transitionState(ExecutionState.RUNNING));
+            }
+
+            return executionGraph;
+        }
     }
 
     /** A helper builder for {@link CheckpointCoordinator} to deduplicate test codes. */
     public static class CheckpointCoordinatorBuilder {
-        private JobID jobId = new JobID();
-
         private CheckpointCoordinatorConfiguration checkpointCoordinatorConfiguration =
                 new CheckpointCoordinatorConfigurationBuilder()
                         .setMaxConcurrentCheckpoints(Integer.MAX_VALUE)
                         .build();
 
-        private ExecutionVertex[] tasksToTrigger;
-
-        private ExecutionVertex[] tasksToWaitFor;
-
-        private ExecutionVertex[] tasksToCommitTo;
+        private ExecutionGraph executionGraph;
 
         private Collection<OperatorCoordinatorCheckpointContext> coordinatorsToCheckpoint =
                 Collections.emptyList();
@@ -661,18 +676,7 @@ public class CheckpointCoordinatorTestingUtils {
         private CheckpointFailureManager failureManager =
                 new CheckpointFailureManager(0, NoOpFailJobCall.INSTANCE);
 
-        public CheckpointCoordinatorBuilder() {
-            ExecutionVertex vertex = mockExecutionVertex(new ExecutionAttemptID());
-            ExecutionVertex[] defaultVertices = new ExecutionVertex[] {vertex};
-            tasksToTrigger = defaultVertices;
-            tasksToWaitFor = defaultVertices;
-            tasksToCommitTo = defaultVertices;
-        }
-
-        public CheckpointCoordinatorBuilder setJobId(JobID jobId) {
-            this.jobId = jobId;
-            return this;
-        }
+        private boolean allowCheckpointsAfterTasksFinished;
 
         public CheckpointCoordinatorBuilder setCheckpointCoordinatorConfiguration(
                 CheckpointCoordinatorConfiguration checkpointCoordinatorConfiguration) {
@@ -680,25 +684,8 @@ public class CheckpointCoordinatorTestingUtils {
             return this;
         }
 
-        public CheckpointCoordinatorBuilder setTasks(ExecutionVertex... tasks) {
-            this.tasksToTrigger = tasks;
-            this.tasksToWaitFor = tasks;
-            this.tasksToCommitTo = tasks;
-            return this;
-        }
-
-        public CheckpointCoordinatorBuilder setTasksToTrigger(ExecutionVertex[] tasksToTrigger) {
-            this.tasksToTrigger = tasksToTrigger;
-            return this;
-        }
-
-        public CheckpointCoordinatorBuilder setTasksToWaitFor(ExecutionVertex[] tasksToWaitFor) {
-            this.tasksToWaitFor = tasksToWaitFor;
-            return this;
-        }
-
-        public CheckpointCoordinatorBuilder setTasksToCommitTo(ExecutionVertex[] tasksToCommitTo) {
-            this.tasksToCommitTo = tasksToCommitTo;
+        public CheckpointCoordinatorBuilder setExecutionGraph(ExecutionGraph executionGraph) {
+            this.executionGraph = executionGraph;
             return this;
         }
 
@@ -753,9 +740,30 @@ public class CheckpointCoordinatorTestingUtils {
             return this;
         }
 
-        public CheckpointCoordinator build() {
+        public CheckpointCoordinatorBuilder setAllowCheckpointsAfterTasksFinished(
+                boolean allowCheckpointsAfterTasksFinished) {
+            this.allowCheckpointsAfterTasksFinished = allowCheckpointsAfterTasksFinished;
+            return this;
+        }
+
+        public CheckpointCoordinator build() throws Exception {
+            if (executionGraph == null) {
+                executionGraph =
+                        new CheckpointExecutionGraphBuilder()
+                                .addJobVertex(new JobVertexID())
+                                .build();
+            }
+
+            DefaultCheckpointPlanCalculator checkpointPlanCalculator =
+                    new DefaultCheckpointPlanCalculator(
+                            executionGraph.getJobID(),
+                            new ExecutionGraphCheckpointPlanCalculatorContext(executionGraph),
+                            executionGraph.getVerticesTopologically());
+            checkpointPlanCalculator.setAllowCheckpointsAfterTasksFinished(
+                    allowCheckpointsAfterTasksFinished);
+
             return new CheckpointCoordinator(
-                    jobId,
+                    executionGraph.getJobID(),
                     checkpointCoordinatorConfiguration,
                     coordinatorsToCheckpoint,
                     checkpointIDCounter,
@@ -766,12 +774,8 @@ public class CheckpointCoordinatorTestingUtils {
                     timer,
                     sharedStateRegistryFactory,
                     failureManager,
-                    new CheckpointPlanCalculator(
-                            jobId,
-                            Arrays.asList(tasksToTrigger),
-                            Arrays.asList(tasksToWaitFor),
-                            Arrays.asList(tasksToCommitTo)),
-                    new ExecutionAttemptMappingProvider(Arrays.asList(tasksToWaitFor)));
+                    checkpointPlanCalculator,
+                    new ExecutionAttemptMappingProvider(executionGraph.getAllExecutionVertices()));
         }
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
index 1b4d154..ca06050 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
@@ -18,17 +18,16 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter;
-import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration.CheckpointCoordinatorConfigurationBuilder;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
@@ -38,24 +37,19 @@ import org.apache.flink.util.TestLogger;
 
 import org.junit.Before;
 import org.junit.Test;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 
 import javax.annotation.Nullable;
 
+import java.util.List;
 import java.util.Optional;
-import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.Executors;
-import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
 import java.util.function.Predicate;
 
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.hamcrest.CoreMatchers.instanceOf;
 import static org.junit.Assert.assertEquals;
@@ -63,9 +57,6 @@ import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.anyLong;
-import static org.mockito.Mockito.doAnswer;
 
 /** Tests for checkpoint coordinator triggering. */
 public class CheckpointCoordinatorTriggeringTest extends TestLogger {
@@ -81,46 +72,20 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
     @Test
     public void testPeriodicTriggering() {
         try {
-            final JobID jid = new JobID();
             final long start = System.currentTimeMillis();
 
-            // create some mock execution vertices and trigger some checkpoint
+            CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                    new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-            final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID();
-            final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
-
-            ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
-            ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID);
-            ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
-
-            final AtomicInteger numCalls = new AtomicInteger();
-
-            final Execution execution = triggerVertex.getCurrentExecutionAttempt();
-
-            doAnswer(
-                            new Answer<Void>() {
-
-                                private long lastId = -1;
-                                private long lastTs = -1;
-
-                                @Override
-                                public Void answer(InvocationOnMock invocation) throws Throwable {
-                                    long id = (Long) invocation.getArguments()[0];
-                                    long ts = (Long) invocation.getArguments()[1];
-
-                                    assertTrue(id > lastId);
-                                    assertTrue(ts >= lastTs);
-                                    assertTrue(ts >= start);
+            JobVertexID jobVertexID = new JobVertexID();
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(jobVertexID)
+                            .setTaskManagerGateway(gateway)
+                            .build();
 
-                                    lastId = id;
-                                    lastTs = ts;
-                                    numCalls.incrementAndGet();
-                                    return null;
-                                }
-                            })
-                    .when(execution)
-                    .triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
+            ExecutionVertex vertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+            ExecutionAttemptID attemptID = vertex.getCurrentExecutionAttempt().getAttemptId();
 
             CheckpointCoordinatorConfiguration checkpointCoordinatorConfiguration =
                     new CheckpointCoordinatorConfigurationBuilder()
@@ -130,47 +95,44 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
                             .build();
             CheckpointCoordinator checkpointCoordinator =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jid)
+                            .setExecutionGraph(graph)
                             .setCheckpointCoordinatorConfiguration(
                                     checkpointCoordinatorConfiguration)
-                            .setTasksToTrigger(new ExecutionVertex[] {triggerVertex})
-                            .setTasksToWaitFor(new ExecutionVertex[] {ackVertex})
-                            .setTasksToCommitTo(new ExecutionVertex[] {commitVertex})
                             .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
 
             checkpointCoordinator.startCheckpointScheduler();
 
-            do {
+            for (int i = 0; i < 5; ++i) {
                 manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
                 manuallyTriggeredScheduledExecutor.triggerAll();
-            } while (numCalls.get() < 5);
-            assertEquals(5, numCalls.get());
+            }
+            checkRecordedTriggeredCheckpoints(5, start, gateway.getTriggeredCheckpoints(attemptID));
 
             checkpointCoordinator.stopCheckpointScheduler();
 
             // no further calls may come.
             manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
             manuallyTriggeredScheduledExecutor.triggerAll();
-            assertEquals(5, numCalls.get());
+            assertEquals(5, gateway.getTriggeredCheckpoints(attemptID).size());
 
             // start another sequence of periodic scheduling
-            numCalls.set(0);
+            gateway.resetCount();
             checkpointCoordinator.startCheckpointScheduler();
 
-            do {
+            for (int i = 0; i < 5; ++i) {
                 manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
                 manuallyTriggeredScheduledExecutor.triggerAll();
-            } while (numCalls.get() < 5);
-            assertEquals(5, numCalls.get());
+            }
+            checkRecordedTriggeredCheckpoints(5, start, gateway.getTriggeredCheckpoints(attemptID));
 
             checkpointCoordinator.stopCheckpointScheduler();
 
             // no further calls may come
             manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
             manuallyTriggeredScheduledExecutor.triggerAll();
-            assertEquals(5, numCalls.get());
+            assertEquals(5, gateway.getTriggeredCheckpoints(attemptID).size());
 
             checkpointCoordinator.shutdown();
         } catch (Exception e) {
@@ -179,28 +141,50 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
         }
     }
 
+    private static void checkRecordedTriggeredCheckpoints(
+            int numTrigger,
+            long start,
+            List<CheckpointCoordinatorTestingUtils.TriggeredCheckpoint> checkpoints) {
+        assertEquals("Unexpected number of triggered checkpoints", numTrigger, checkpoints.size());
+        // Ids must strictly increase; timestamps must not decrease nor precede start.
+        long lastId = -1;
+        long lastTs = -1;
+
+        for (CheckpointCoordinatorTestingUtils.TriggeredCheckpoint checkpoint : checkpoints) {
+            assertTrue(
+                    "Triggered checkpoint ids should be strictly increasing",
+                    checkpoint.checkpointId > lastId);
+            assertTrue(
+                    "Triggered checkpoint timestamps should be non-decreasing",
+                    checkpoint.timestamp >= lastTs);
+            assertTrue(
+                    "Triggered checkpoint timestamp should not be earlier than the start time",
+                    checkpoint.timestamp >= start);
+
+            lastId = checkpoint.checkpointId;
+            lastTs = checkpoint.timestamp;
+        }
+    }
+
     /**
      * This test verified that after a completed checkpoint a certain time has passed before another
      * is triggered.
      */
     @Test
     public void testMinTimeBetweenCheckpointsInterval() throws Exception {
-        final JobID jid = new JobID();
+        JobVertexID jobVertexID = new JobVertexID();
 
-        // create some mock execution vertices and trigger some checkpoint
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        final ExecutionVertex vertex = mockExecutionVertex(attemptID);
-        final Execution executionAttempt = vertex.getCurrentExecutionAttempt();
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
 
-        final BlockingQueue<Long> triggerCalls = new LinkedBlockingQueue<>();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .setTaskManagerGateway(gateway)
+                        .build();
 
-        doAnswer(
-                        invocation -> {
-                            triggerCalls.add((Long) invocation.getArguments()[0]);
-                            return null;
-                        })
-                .when(executionAttempt)
-                .triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
+        ExecutionVertex vertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+        ExecutionAttemptID attemptID = vertex.getCurrentExecutionAttempt().getAttemptId();
 
         final long delay = 50;
         final long checkpointInterval = 12;
@@ -214,9 +198,8 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
                         .build();
         final CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setJobId(jid)
+                        .setExecutionGraph(graph)
                         .setCheckpointCoordinatorConfiguration(checkpointCoordinatorConfiguration)
-                        .setTasks(new ExecutionVertex[] {vertex})
                         .setCompletedCheckpointStore(new StandaloneCompletedCheckpointStore(2))
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
@@ -227,25 +210,27 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
             manuallyTriggeredScheduledExecutor.triggerAll();
 
             // wait until the first checkpoint was triggered
-            Long firstCallId = triggerCalls.take();
+            Long firstCallId = gateway.getTriggeredCheckpoints(attemptID).get(0).checkpointId;
             assertEquals(1L, firstCallId.longValue());
 
-            AcknowledgeCheckpoint ackMsg = new AcknowledgeCheckpoint(jid, attemptID, 1L);
+            AcknowledgeCheckpoint ackMsg =
+                    new AcknowledgeCheckpoint(graph.getJobID(), attemptID, 1L);
 
             // tell the coordinator that the checkpoint is done
             final long ackTime = System.nanoTime();
             checkpointCoordinator.receiveAcknowledgeMessage(ackMsg, TASK_MANAGER_LOCATION_INFO);
 
+            gateway.resetCount();
             manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
             manuallyTriggeredScheduledExecutor.triggerAll();
-            while (triggerCalls.isEmpty()) {
+            while (gateway.getTriggeredCheckpoints(attemptID).isEmpty()) {
                 // sleeps for a while to simulate periodic scheduling
                 Thread.sleep(checkpointInterval);
                 manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
                 manuallyTriggeredScheduledExecutor.triggerAll();
             }
             // wait until the next checkpoint is triggered
-            Long nextCallId = triggerCalls.take();
+            Long nextCallId = gateway.getTriggeredCheckpoints(attemptID).get(0).checkpointId;
             final long nextCheckpointTime = System.nanoTime();
             assertEquals(2L, nextCallId.longValue());
 
@@ -321,16 +306,22 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 
     @Test
     public void testTriggerCheckpointBeforePreviousOneCompleted() throws Exception {
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        final AtomicInteger taskManagerCheckpointTriggeredTimes = new AtomicInteger(0);
-        final SimpleAckingTaskManagerGateway.CheckpointConsumer checkpointConsumer =
-                (executionAttemptID, jobId, checkpointId, timestamp, checkpointOptions) ->
-                        taskManagerCheckpointTriggeredTimes.incrementAndGet();
-        ExecutionVertex vertex = mockExecutionVertex(attemptID, checkpointConsumer);
+        JobVertexID jobVertexID = new JobVertexID();
+
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+        ExecutionAttemptID attemptID = vertex.getCurrentExecutionAttempt().getAttemptId();
 
         // set up the coordinator and validate the initial state
-        CheckpointCoordinator checkpointCoordinator = createCheckpointCoordinator(vertex);
+        CheckpointCoordinator checkpointCoordinator = createCheckpointCoordinator(graph);
 
         checkpointCoordinator.startCheckpointScheduler();
         // start a periodic checkpoint first
@@ -350,23 +341,29 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
         assertFalse(onCompletionPromise2.isCompletedExceptionally());
         assertFalse(checkpointCoordinator.isTriggering());
         assertEquals(0, checkpointCoordinator.getTriggerRequestQueue().size());
-        assertEquals(2, taskManagerCheckpointTriggeredTimes.get());
+        assertEquals(2, gateway.getTriggeredCheckpoints(attemptID).size());
     }
 
     @Test
     public void testTriggerCheckpointRequestQueuedWithFailure() throws Exception {
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        final AtomicInteger taskManagerCheckpointTriggeredTimes = new AtomicInteger(0);
-        final SimpleAckingTaskManagerGateway.CheckpointConsumer checkpointConsumer =
-                (executionAttemptID, jobId, checkpointId, timestamp, checkpointOptions) ->
-                        taskManagerCheckpointTriggeredTimes.incrementAndGet();
-        ExecutionVertex vertex = mockExecutionVertex(attemptID, checkpointConsumer);
+        JobVertexID jobVertexID = new JobVertexID();
+
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+        ExecutionAttemptID attemptID = vertex.getCurrentExecutionAttempt().getAttemptId();
 
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setTasks(new ExecutionVertex[] {vertex})
+                        .setExecutionGraph(graph)
                         .setCheckpointIDCounter(new UnstableCheckpointIDCounter(id -> id == 0))
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
@@ -395,21 +392,27 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
         assertFalse(onCompletionPromise3.isCompletedExceptionally());
         assertFalse(checkpointCoordinator.isTriggering());
         assertEquals(0, checkpointCoordinator.getTriggerRequestQueue().size());
-        assertEquals(2, taskManagerCheckpointTriggeredTimes.get());
+        assertEquals(2, gateway.getTriggeredCheckpoints(attemptID).size());
     }
 
     @Test
     public void testTriggerCheckpointRequestCancelled() throws Exception {
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        final AtomicInteger taskManagerCheckpointTriggeredTimes = new AtomicInteger(0);
-        final SimpleAckingTaskManagerGateway.CheckpointConsumer checkpointConsumer =
-                (executionAttemptID, jobId, checkpointId, timestamp, checkpointOptions) ->
-                        taskManagerCheckpointTriggeredTimes.incrementAndGet();
-        ExecutionVertex vertex = mockExecutionVertex(attemptID, checkpointConsumer);
+        JobVertexID jobVertexID = new JobVertexID();
+
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+        ExecutionAttemptID attemptID = vertex.getCurrentExecutionAttempt().getAttemptId();
 
         // set up the coordinator and validate the initial state
-        CheckpointCoordinator checkpointCoordinator = createCheckpointCoordinator(vertex);
+        CheckpointCoordinator checkpointCoordinator = createCheckpointCoordinator(graph);
 
         final CompletableFuture<String> masterHookCheckpointFuture = new CompletableFuture<>();
         checkpointCoordinator.addMasterHook(new TestingMasterHook(masterHookCheckpointFuture));
@@ -443,20 +446,15 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
         manuallyTriggeredScheduledExecutor.triggerAll();
         assertFalse(checkpointCoordinator.isTriggering());
         // it doesn't really trigger task manager to do checkpoint
-        assertEquals(0, taskManagerCheckpointTriggeredTimes.get());
+        assertEquals(0, gateway.getTriggeredCheckpoints(attemptID).size());
         assertEquals(0, checkpointCoordinator.getTriggerRequestQueue().size());
     }
 
     @Test
     public void testTriggerCheckpointInitializationFailed() throws Exception {
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        ExecutionVertex vertex = mockExecutionVertex(attemptID);
-
         // set up the coordinator and validate the initial state
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorBuilder()
-                        .setTasks(new ExecutionVertex[] {vertex})
                         .setCheckpointIDCounter(new UnstableCheckpointIDCounter(id -> id == 0))
                         .setTimer(manuallyTriggeredScheduledExecutor)
                         .build();
@@ -493,16 +491,22 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 
     @Test
     public void testTriggerCheckpointSnapshotMasterHookFailed() throws Exception {
-        // create some mock Execution vertices that receive the checkpoint trigger messages
-        final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-        final AtomicInteger taskManagerCheckpointTriggeredTimes = new AtomicInteger(0);
-        final SimpleAckingTaskManagerGateway.CheckpointConsumer checkpointConsumer =
-                (executionAttemptID, jobId, checkpointId, timestamp, checkpointOptions) ->
-                        taskManagerCheckpointTriggeredTimes.incrementAndGet();
-        ExecutionVertex vertex = mockExecutionVertex(attemptID, checkpointConsumer);
+        JobVertexID jobVertexID = new JobVertexID();
+
+        CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway gateway =
+                new CheckpointCoordinatorTestingUtils.CheckpointRecorderTaskManagerGateway();
+
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .setTaskManagerGateway(gateway)
+                        .build();
+
+        ExecutionVertex vertex = graph.getJobVertex(jobVertexID).getTaskVertices()[0];
+        ExecutionAttemptID attemptID = vertex.getCurrentExecutionAttempt().getAttemptId();
 
         // set up the coordinator and validate the initial state
-        CheckpointCoordinator checkpointCoordinator = createCheckpointCoordinator(vertex);
+        CheckpointCoordinator checkpointCoordinator = createCheckpointCoordinator(graph);
 
         final CompletableFuture<String> masterHookCheckpointFuture = new CompletableFuture<>();
         checkpointCoordinator.addMasterHook(new TestingMasterHook(masterHookCheckpointFuture));
@@ -532,20 +536,17 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
                     checkpointExceptionOptional.get().getCheckpointFailureReason());
         }
         // it doesn't really trigger task manager to do checkpoint
-        assertEquals(0, taskManagerCheckpointTriggeredTimes.get());
+        assertEquals(0, gateway.getTriggeredCheckpoints(attemptID).size());
         assertEquals(0, checkpointCoordinator.getTriggerRequestQueue().size());
     }
 
     /** This test only fails eventually. */
     @Test
     public void discardingTriggeringCheckpointWillExecuteNextCheckpointRequest() throws Exception {
-        final ExecutionVertex executionVertex = mockExecutionVertex(new ExecutionAttemptID());
-
         final ScheduledExecutorService scheduledExecutorService =
                 Executors.newSingleThreadScheduledExecutor();
         final CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinatorTestingUtils.CheckpointCoordinatorBuilder()
-                        .setTasks(new ExecutionVertex[] {executionVertex})
                         .setTimer(new ScheduledExecutorServiceAdapter(scheduledExecutorService))
                         .setCheckpointCoordinatorConfiguration(
                                 CheckpointCoordinatorConfiguration.builder().build())
@@ -584,15 +585,16 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
         }
     }
 
-    private CheckpointCoordinator createCheckpointCoordinator() {
+    private CheckpointCoordinator createCheckpointCoordinator() throws Exception {
         return new CheckpointCoordinatorBuilder()
                 .setTimer(manuallyTriggeredScheduledExecutor)
                 .build();
     }
 
-    private CheckpointCoordinator createCheckpointCoordinator(ExecutionVertex executionVertex) {
+    private CheckpointCoordinator createCheckpointCoordinator(ExecutionGraph graph)
+            throws Exception {
         return new CheckpointCoordinatorBuilder()
-                .setTasks(new ExecutionVertex[] {executionVertex})
+                .setExecutionGraph(graph)
                 .setTimer(manuallyTriggeredScheduledExecutor)
                 .build();
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 5e34085..9f97506 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
@@ -36,27 +37,22 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
 import org.apache.flink.util.SerializableObject;
 
-import org.hamcrest.BaseMatcher;
-import org.hamcrest.Description;
 import org.junit.Test;
-import org.mockito.Mockito;
-import org.mockito.hamcrest.MockitoHamcrest;
 
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Set;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /** Tests concerning the restoring of state from a checkpoint to the task executions. */
@@ -68,7 +64,6 @@ public class CheckpointStateRestoreTest {
     @Test
     public void testSetState() {
         try {
-
             KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
             List<SerializableObject> testStates =
                     Collections.singletonList(new SerializableObject());
@@ -76,48 +71,36 @@ public class CheckpointStateRestoreTest {
                     CheckpointCoordinatorTestingUtils.generateKeyGroupState(
                             keyGroupRange, testStates);
 
-            final JobID jid = new JobID();
             final JobVertexID statefulId = new JobVertexID();
             final JobVertexID statelessId = new JobVertexID();
 
-            Execution statefulExec1 = mockExecution();
-            Execution statefulExec2 = mockExecution();
-            Execution statefulExec3 = mockExecution();
-            Execution statelessExec1 = mockExecution();
-            Execution statelessExec2 = mockExecution();
-
-            ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 3);
-            ExecutionVertex stateful2 = mockExecutionVertex(statefulExec2, statefulId, 1, 3);
-            ExecutionVertex stateful3 = mockExecutionVertex(statefulExec3, statefulId, 2, 3);
-            ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 2);
-            ExecutionVertex stateless2 = mockExecutionVertex(statelessExec2, statelessId, 1, 2);
-
-            ExecutionJobVertex stateful =
-                    mockExecutionJobVertex(
-                            statefulId, new ExecutionVertex[] {stateful1, stateful2, stateful3});
-            ExecutionJobVertex stateless =
-                    mockExecutionJobVertex(
-                            statelessId, new ExecutionVertex[] {stateless1, stateless2});
-
-            Set<ExecutionJobVertex> tasks = new HashSet<>();
-            tasks.add(stateful);
-            tasks.add(stateless);
+            ExecutionGraph graph =
+                    new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                            .addJobVertex(statefulId, 3, 256)
+                            .addJobVertex(statelessId, 2, 256)
+                            .build();
+
+            ExecutionJobVertex stateful = graph.getJobVertex(statefulId);
+            ExecutionJobVertex stateless = graph.getJobVertex(statelessId);
+
+            ExecutionVertex stateful1 = stateful.getTaskVertices()[0];
+            ExecutionVertex stateful2 = stateful.getTaskVertices()[1];
+            ExecutionVertex stateful3 = stateful.getTaskVertices()[2];
+            ExecutionVertex stateless1 = stateless.getTaskVertices()[0];
+            ExecutionVertex stateless2 = stateless.getTaskVertices()[1];
+
+            Execution statefulExec1 = stateful1.getCurrentExecutionAttempt();
+            Execution statefulExec2 = stateful2.getCurrentExecutionAttempt();
+            Execution statefulExec3 = stateful3.getCurrentExecutionAttempt();
+            Execution statelessExec1 = stateless1.getCurrentExecutionAttempt();
+            Execution statelessExec2 = stateless2.getCurrentExecutionAttempt();
 
             ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
                     new ManuallyTriggeredScheduledExecutor();
 
             CheckpointCoordinator coord =
                     new CheckpointCoordinatorBuilder()
-                            .setJobId(jid)
-                            .setTasksToTrigger(
-                                    new ExecutionVertex[] {
-                                        stateful1, stateful2, stateful3, stateless1, stateless2
-                                    })
-                            .setTasksToWaitFor(
-                                    new ExecutionVertex[] {
-                                        stateful1, stateful2, stateful3, stateless1, stateless2
-                                    })
-                            .setTasksToCommitTo(new ExecutionVertex[0])
+                            .setExecutionGraph(graph)
                             .setTimer(manuallyTriggeredScheduledExecutor)
                             .build();
 
@@ -138,7 +121,7 @@ public class CheckpointStateRestoreTest {
 
             coord.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             statefulExec1.getAttemptId(),
                             checkpointId,
                             new CheckpointMetrics(),
@@ -146,7 +129,7 @@ public class CheckpointStateRestoreTest {
                     TASK_MANAGER_LOCATION_INFO);
             coord.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             statefulExec2.getAttemptId(),
                             checkpointId,
                             new CheckpointMetrics(),
@@ -154,50 +137,35 @@ public class CheckpointStateRestoreTest {
                     TASK_MANAGER_LOCATION_INFO);
             coord.receiveAcknowledgeMessage(
                     new AcknowledgeCheckpoint(
-                            jid,
+                            graph.getJobID(),
                             statefulExec3.getAttemptId(),
                             checkpointId,
                             new CheckpointMetrics(),
                             subtaskStates),
                     TASK_MANAGER_LOCATION_INFO);
             coord.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId),
+                    new AcknowledgeCheckpoint(
+                            graph.getJobID(), statelessExec1.getAttemptId(), checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
             coord.receiveAcknowledgeMessage(
-                    new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId),
+                    new AcknowledgeCheckpoint(
+                            graph.getJobID(), statelessExec2.getAttemptId(), checkpointId),
                     TASK_MANAGER_LOCATION_INFO);
 
             assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
             assertEquals(0, coord.getNumberOfPendingCheckpoints());
 
             // let the coordinator inject the state
-            assertTrue(coord.restoreLatestCheckpointedStateToAll(tasks, false));
+            assertTrue(
+                    coord.restoreLatestCheckpointedStateToAll(
+                            new HashSet<>(Arrays.asList(stateful, stateless)), false));
 
             // verify that each stateful vertex got the state
-
-            BaseMatcher<JobManagerTaskRestore> matcher =
-                    new BaseMatcher<JobManagerTaskRestore>() {
-                        @Override
-                        public boolean matches(Object o) {
-                            if (o instanceof JobManagerTaskRestore) {
-                                JobManagerTaskRestore taskRestore = (JobManagerTaskRestore) o;
-                                return Objects.equals(
-                                        taskRestore.getTaskStateSnapshot(), subtaskStates);
-                            }
-                            return false;
-                        }
-
-                        @Override
-                        public void describeTo(Description description) {
-                            description.appendValue(subtaskStates);
-                        }
-                    };
-
-            verify(statefulExec1, times(1)).setInitialState(MockitoHamcrest.argThat(matcher));
-            verify(statefulExec2, times(1)).setInitialState(MockitoHamcrest.argThat(matcher));
-            verify(statefulExec3, times(1)).setInitialState(MockitoHamcrest.argThat(matcher));
-            verify(statelessExec1, times(0)).setInitialState(Mockito.<JobManagerTaskRestore>any());
-            verify(statelessExec2, times(0)).setInitialState(Mockito.<JobManagerTaskRestore>any());
+            assertEquals(subtaskStates, statefulExec1.getTaskRestore().getTaskStateSnapshot());
+            assertEquals(subtaskStates, statefulExec2.getTaskRestore().getTaskStateSnapshot());
+            assertEquals(subtaskStates, statefulExec3.getTaskRestore().getTaskStateSnapshot());
+            assertNull(statelessExec1.getTaskRestore());
+            assertNull(statelessExec2.getTaskRestore());
         } catch (Exception e) {
             e.printStackTrace();
             fail(e.getMessage());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStatsTrackerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStatsTrackerTest.java
index fa92cfc..3c134cb 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStatsTrackerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStatsTrackerTest.java
@@ -21,6 +21,8 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.metrics.Gauge;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
@@ -75,9 +77,12 @@ public class CheckpointStatsTrackerTest {
     /** Tests that the number of remembered checkpoints configuration is respected. */
     @Test
     public void testTrackerWithoutHistory() throws Exception {
-        int numberOfSubtasks = 3;
-
-        JobVertexID vertexID = new JobVertexID();
+        JobVertexID jobVertexID = new JobVertexID();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID, 3, 256)
+                        .build();
+        ExecutionJobVertex jobVertex = graph.getJobVertex(jobVertexID);
 
         CheckpointStatsTracker tracker =
                 new CheckpointStatsTracker(
@@ -91,11 +96,11 @@ public class CheckpointStatsTrackerTest {
                         1,
                         CheckpointProperties.forCheckpoint(
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-                        singletonMap(vertexID, numberOfSubtasks));
+                        singletonMap(jobVertexID, jobVertex.getParallelism()));
 
-        pending.reportSubtaskStats(vertexID, createSubtaskStats(0));
-        pending.reportSubtaskStats(vertexID, createSubtaskStats(1));
-        pending.reportSubtaskStats(vertexID, createSubtaskStats(2));
+        pending.reportSubtaskStats(jobVertexID, createSubtaskStats(0));
+        pending.reportSubtaskStats(jobVertexID, createSubtaskStats(1));
+        pending.reportSubtaskStats(jobVertexID, createSubtaskStats(2));
 
         pending.reportCompletedCheckpoint(null);
 
@@ -121,10 +126,14 @@ public class CheckpointStatsTrackerTest {
     /** Tests tracking of checkpoints. */
     @Test
     public void testCheckpointTracking() throws Exception {
-        int numberOfSubtasks = 3;
-
-        JobVertexID vertexID = new JobVertexID();
-        Map<JobVertexID, Integer> vertexToDop = singletonMap(vertexID, numberOfSubtasks);
+        JobVertexID jobVertexID = new JobVertexID();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID, 3, 256)
+                        .build();
+        ExecutionJobVertex jobVertex = graph.getJobVertex(jobVertexID);
+        Map<JobVertexID, Integer> vertexToDop =
+                singletonMap(jobVertexID, jobVertex.getParallelism());
 
         CheckpointStatsTracker tracker =
                 new CheckpointStatsTracker(
@@ -141,9 +150,9 @@ public class CheckpointStatsTrackerTest {
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
                         vertexToDop);
 
-        completed1.reportSubtaskStats(vertexID, createSubtaskStats(0));
-        completed1.reportSubtaskStats(vertexID, createSubtaskStats(1));
-        completed1.reportSubtaskStats(vertexID, createSubtaskStats(2));
+        completed1.reportSubtaskStats(jobVertexID, createSubtaskStats(0));
+        completed1.reportSubtaskStats(jobVertexID, createSubtaskStats(1));
+        completed1.reportSubtaskStats(jobVertexID, createSubtaskStats(2));
 
         completed1.reportCompletedCheckpoint(null);
 
@@ -163,9 +172,9 @@ public class CheckpointStatsTrackerTest {
                 tracker.reportPendingCheckpoint(
                         2, 1, CheckpointProperties.forSavepoint(true), vertexToDop);
 
-        savepoint.reportSubtaskStats(vertexID, createSubtaskStats(0));
-        savepoint.reportSubtaskStats(vertexID, createSubtaskStats(1));
-        savepoint.reportSubtaskStats(vertexID, createSubtaskStats(2));
+        savepoint.reportSubtaskStats(jobVertexID, createSubtaskStats(0));
+        savepoint.reportSubtaskStats(jobVertexID, createSubtaskStats(1));
+        savepoint.reportSubtaskStats(jobVertexID, createSubtaskStats(2));
 
         savepoint.reportCompletedCheckpoint(null);
 
@@ -243,8 +252,7 @@ public class CheckpointStatsTrackerTest {
     /** Tests that snapshots are only created if a new snapshot has been reported or updated. */
     @Test
     public void testCreateSnapshot() throws Exception {
-        JobVertexID jobVertexId = new JobVertexID();
-
+        JobVertexID jobVertexID = new JobVertexID();
         CheckpointStatsTracker tracker =
                 new CheckpointStatsTracker(
                         10,
@@ -260,9 +268,9 @@ public class CheckpointStatsTrackerTest {
                         1,
                         CheckpointProperties.forCheckpoint(
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-                        singletonMap(jobVertexId, 1));
+                        singletonMap(jobVertexID, 1));
 
-        pending.reportSubtaskStats(jobVertexId, createSubtaskStats(0));
+        pending.reportSubtaskStats(jobVertexID, createSubtaskStats(0));
 
         CheckpointStatsSnapshot snapshot2 = tracker.createSnapshot();
         assertNotEquals(snapshot1, snapshot2);
@@ -345,7 +353,12 @@ public class CheckpointStatsTrackerTest {
                     }
                 };
 
-        JobVertexID vertexID = new JobVertexID();
+        JobVertexID jobVertexID = new JobVertexID();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(jobVertexID)
+                        .build();
+        ExecutionJobVertex jobVertex = graph.getJobVertex(jobVertexID);
 
         CheckpointStatsTracker stats =
                 new CheckpointStatsTracker(
@@ -415,7 +428,7 @@ public class CheckpointStatsTrackerTest {
                         0,
                         CheckpointProperties.forCheckpoint(
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-                        singletonMap(vertexID, 1));
+                        singletonMap(jobVertexID, 1));
 
         // Check counts
         assertEquals(Long.valueOf(1), numCheckpoints.getValue());
@@ -444,7 +457,7 @@ public class CheckpointStatsTrackerTest {
                         false,
                         true);
 
-        assertTrue(pending.reportSubtaskStats(vertexID, subtaskStats));
+        assertTrue(pending.reportSubtaskStats(jobVertexID, subtaskStats));
 
         pending.reportCompletedCheckpoint(externalPath);
 
@@ -467,7 +480,7 @@ public class CheckpointStatsTrackerTest {
                         11,
                         CheckpointProperties.forCheckpoint(
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-                        singletonMap(vertexID, 1));
+                        singletonMap(jobVertexID, 1));
 
         long failureTimestamp = 1230123L;
         nextPending.reportFailedCheckpoint(failureTimestamp, null);
@@ -503,9 +516,9 @@ public class CheckpointStatsTrackerTest {
                         5000,
                         CheckpointProperties.forCheckpoint(
                                 CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-                        singletonMap(vertexID, 1));
+                        singletonMap(jobVertexID, 1));
 
-        thirdPending.reportSubtaskStats(vertexID, subtaskStats);
+        thirdPending.reportSubtaskStats(jobVertexID, subtaskStats);
         thirdPending.reportCompletedCheckpoint(null);
 
         // Verify external path is "n/a", because internal checkpoint won't generate external path.
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculatorTest.java
new file mode 100644
index 0000000..849de56
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/DefaultCheckpointPlanCalculatorTest.java
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphCheckpointPlanCalculatorContext;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.testtasks.NoOpInvokable;
+
+import org.hamcrest.CoreMatchers;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.fail;
+
+/**
+ * Declarative tests for {@link DefaultCheckpointPlanCalculator}.
+ *
+ * <p>This test contains a framework for declaring vertex and edge states to then assert the
+ * calculator behavior.
+ */
+public class DefaultCheckpointPlanCalculatorTest {
+
+    @Test
+    public void testComputeAllRunningGraph() throws Exception {
+        // Only vertices 0 and 1 have no inputs, so only their subtasks are triggered.
+        runSingleTest(
+                Arrays.asList(
+                        new VertexDeclaration(3, Collections.emptySet()),
+                        new VertexDeclaration(4, Collections.emptySet()),
+                        new VertexDeclaration(5, Collections.emptySet()),
+                        new VertexDeclaration(6, Collections.emptySet())),
+                Arrays.asList(
+                        new EdgeDeclaration(0, 2, DistributionPattern.ALL_TO_ALL),
+                        new EdgeDeclaration(1, 2, DistributionPattern.POINTWISE),
+                        new EdgeDeclaration(2, 3, DistributionPattern.ALL_TO_ALL)),
+                Arrays.asList(
+                        new TaskDeclaration(0, range(0, 3)), new TaskDeclaration(1, range(0, 4))));
+    }
+
+    @Test
+    public void testAllToAllEdgeWithSomeSourcesFinished() throws Exception {
+        // The running source subtask feeds every downstream subtask, so only it is triggered.
+        runSingleTest(
+                Arrays.asList(
+                        new VertexDeclaration(3, range(0, 2)),
+                        new VertexDeclaration(4, Collections.emptySet())),
+                Collections.singletonList(
+                        new EdgeDeclaration(0, 1, DistributionPattern.ALL_TO_ALL)),
+                Collections.singletonList(new TaskDeclaration(0, range(2, 3))));
+    }
+
+    @Test
+    public void testOneToOneEdgeWithSomeSourcesFinished() throws Exception {
+        // Downstream subtasks whose only upstream subtask finished are triggered directly.
+        runSingleTest(
+                Arrays.asList(
+                        new VertexDeclaration(4, range(0, 2)),
+                        new VertexDeclaration(4, Collections.emptySet())),
+                Collections.singletonList(new EdgeDeclaration(0, 1, DistributionPattern.POINTWISE)),
+                Arrays.asList(
+                        new TaskDeclaration(0, range(2, 4)), new TaskDeclaration(1, range(0, 2))));
+    }
+
+    @Test
+    public void testOneToOneEdgeWithSomeSourcesAndTargetsFinished() throws Exception {
+        // Downstream subtask 0 has finished as well, so only downstream subtask 1 is triggered.
+        runSingleTest(
+                Arrays.asList(
+                        new VertexDeclaration(4, range(0, 2)), new VertexDeclaration(4, of(0))),
+                Collections.singletonList(new EdgeDeclaration(0, 1, DistributionPattern.POINTWISE)),
+                Arrays.asList(
+                        new TaskDeclaration(0, range(2, 4)), new TaskDeclaration(1, range(1, 2))));
+    }
+
+    @Test
+    public void testComputeWithMultipleInputs() throws Exception {
+        // Every unfinished subtask of vertex 3 still has a running input, so none is triggered.
+        runSingleTest(
+                Arrays.asList(
+                        new VertexDeclaration(3, range(0, 3)),
+                        new VertexDeclaration(5, of(0, 2, 3)),
+                        new VertexDeclaration(5, of(2, 4)),
+                        new VertexDeclaration(5, of(2))),
+                Arrays.asList(
+                        new EdgeDeclaration(0, 3, DistributionPattern.ALL_TO_ALL),
+                        new EdgeDeclaration(1, 3, DistributionPattern.POINTWISE),
+                        new EdgeDeclaration(2, 3, DistributionPattern.POINTWISE)),
+                Arrays.asList(
+                        new TaskDeclaration(1, of(1, 4)), new TaskDeclaration(2, of(0, 1, 3))));
+    }
+
+    @Test
+    public void testComputeWithMultipleLevels() throws Exception {
+        runSingleTest(
+                Arrays.asList(
+                        new VertexDeclaration(16, range(0, 4)),
+                        new VertexDeclaration(16, range(0, 16)),
+                        new VertexDeclaration(16, range(0, 2)),
+                        new VertexDeclaration(16, Collections.emptySet()),
+                        new VertexDeclaration(16, Collections.emptySet())),
+                Arrays.asList(
+                        new EdgeDeclaration(0, 2, DistributionPattern.POINTWISE),
+                        new EdgeDeclaration(0, 3, DistributionPattern.POINTWISE),
+                        new EdgeDeclaration(1, 2, DistributionPattern.ALL_TO_ALL),
+                        new EdgeDeclaration(1, 3, DistributionPattern.POINTWISE),
+                        new EdgeDeclaration(2, 4, DistributionPattern.POINTWISE),
+                        new EdgeDeclaration(3, 4, DistributionPattern.ALL_TO_ALL)),
+                Arrays.asList(
+                        new TaskDeclaration(0, range(4, 16)),
+                        new TaskDeclaration(2, range(2, 4)),
+                        new TaskDeclaration(3, range(0, 4))));
+    }
+
+    @Test
+    public void testWithTriggeredTasksNotRunning() throws Exception {
+        // The graph is built without transitioning its tasks to running.
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .setTransitToRunning(false)
+                        .build();
+        DefaultCheckpointPlanCalculator checkpointPlanCalculator =
+                createCheckpointPlanCalculator(graph);
+
+        try {
+            checkpointPlanCalculator.calculateCheckpointPlan().get();
+            fail("The computation should fail since not all tasks to trigger have started running");
+        } catch (ExecutionException e) {
+            Throwable cause = e.getCause();
+            assertThat(cause, instanceOf(CheckpointException.class));
+            assertEquals(
+                    CheckpointFailureReason.NOT_ALL_REQUIRED_TASKS_RUNNING,
+                    ((CheckpointException) cause).getCheckpointFailureReason());
+        }
+    }
+
+    // ------------------------- Utility methods ---------------------------------------
+
+    /**
+     * Runs a declarative test in which the expected finished tasks are exactly the subtasks that
+     * the vertex declarations mark as finished.
+     */
+    private void runSingleTest(
+            List<VertexDeclaration> vertexDeclarations,
+            List<EdgeDeclaration> edgeDeclarations,
+            List<TaskDeclaration> expectedToTriggerTaskDeclarations)
+            throws Exception {
+        runSingleTest(
+                vertexDeclarations,
+                edgeDeclarations,
+                expectedToTriggerTaskDeclarations,
+                IntStream.range(0, vertexDeclarations.size())
+                        .mapToObj(
+                                i ->
+                                        new TaskDeclaration(
+                                                i,
+                                                vertexDeclarations.get(i).finishedSubtaskIndices))
+                        .collect(Collectors.toList()));
+    }
+
+    /**
+     * Builds the declared graph, computes its checkpoint plan and checks the plan.
+     *
+     * <p>For every vertex listed in {@code expectedFinishedTaskDeclarations}, the subtasks not
+     * declared there are expected to be running, and the vertex is expected to be fully finished
+     * if all of its subtasks are declared there.
+     *
+     * @param vertexDeclarations parallelism and finished subtasks of each vertex, by index
+     * @param edgeDeclarations edges between the declared vertices
+     * @param expectedToTriggerTaskDeclarations subtasks the plan is expected to trigger
+     * @param expectedFinishedTaskDeclarations subtasks the plan is expected to report as finished
+     */
+    private void runSingleTest(
+            List<VertexDeclaration> vertexDeclarations,
+            List<EdgeDeclaration> edgeDeclarations,
+            List<TaskDeclaration> expectedToTriggerTaskDeclarations,
+            List<TaskDeclaration> expectedFinishedTaskDeclarations)
+            throws Exception {
+
+        ExecutionGraph graph = createExecutionGraph(vertexDeclarations, edgeDeclarations);
+        DefaultCheckpointPlanCalculator planCalculator = createCheckpointPlanCalculator(graph);
+
+        List<TaskDeclaration> expectedRunningTaskDeclarations = new ArrayList<>();
+        List<ExecutionJobVertex> expectedFullyFinishedJobVertices = new ArrayList<>();
+
+        expectedFinishedTaskDeclarations.forEach(
+                finishedDeclaration -> {
+                    ExecutionJobVertex jobVertex =
+                            chooseJobVertex(graph, finishedDeclaration.vertexIndex);
+                    expectedRunningTaskDeclarations.add(
+                            new TaskDeclaration(
+                                    finishedDeclaration.vertexIndex,
+                                    minus(
+                                            range(0, jobVertex.getParallelism()),
+                                            finishedDeclaration.subtaskIndices)));
+                    if (finishedDeclaration.subtaskIndices.size() == jobVertex.getParallelism()) {
+                        expectedFullyFinishedJobVertices.add(jobVertex);
+                    }
+                });
+
+        List<ExecutionVertex> expectedRunningTasks =
+                chooseTasks(graph, expectedRunningTaskDeclarations.toArray(new TaskDeclaration[0]));
+        List<Execution> expectedFinishedTasks =
+                chooseTasks(graph, expectedFinishedTaskDeclarations.toArray(new TaskDeclaration[0]))
+                        .stream()
+                        .map(ExecutionVertex::getCurrentExecutionAttempt)
+                        .collect(Collectors.toList());
+        List<ExecutionVertex> expectedToTriggerTasks =
+                chooseTasks(
+                        graph, expectedToTriggerTaskDeclarations.toArray(new TaskDeclaration[0]));
+
+        // Tests computing checkpoint plan
+        CheckpointPlan checkpointPlan = planCalculator.calculateCheckpointPlan().get();
+        checkCheckpointPlan(
+                expectedToTriggerTasks,
+                expectedRunningTasks,
+                expectedFinishedTasks,
+                expectedFullyFinishedJobVertices,
+                checkpointPlan);
+    }
+
+    /**
+     * Creates an execution graph with one {@link NoOpInvokable} vertex per declaration, connected
+     * by pipelined edges. All executions are switched to RUNNING before the declared subtasks are
+     * marked as finished.
+     */
+    private ExecutionGraph createExecutionGraph(
+            List<VertexDeclaration> vertexDeclarations, List<EdgeDeclaration> edgeDeclarations)
+            throws Exception {
+
+        JobVertex[] jobVertices = new JobVertex[vertexDeclarations.size()];
+        for (int i = 0; i < vertexDeclarations.size(); ++i) {
+            jobVertices[i] =
+                    ExecutionGraphTestUtils.createJobVertex(
+                            vertexName(i),
+                            vertexDeclarations.get(i).parallelism,
+                            NoOpInvokable.class);
+        }
+
+        for (EdgeDeclaration edgeDeclaration : edgeDeclarations) {
+            jobVertices[edgeDeclaration.target].connectNewDataSetAsInput(
+                    jobVertices[edgeDeclaration.source],
+                    edgeDeclaration.distributionPattern,
+                    ResultPartitionType.PIPELINED);
+        }
+
+        ExecutionGraph graph = ExecutionGraphTestUtils.createSimpleTestGraph(jobVertices);
+        graph.start(ComponentMainThreadExecutorServiceAdapter.forMainThread());
+        graph.transitionToRunning();
+        graph.getAllExecutionVertices()
+                .forEach(
+                        task ->
+                                task.getCurrentExecutionAttempt()
+                                        .transitionState(ExecutionState.RUNNING));
+
+        for (int i = 0; i < vertexDeclarations.size(); ++i) {
+            JobVertexID jobVertexId = jobVertices[i].getID();
+            vertexDeclarations
+                    .get(i)
+                    .finishedSubtaskIndices
+                    .forEach(
+                            index -> {
+                                graph.getJobVertex(jobVertexId)
+                                        .getTaskVertices()[index]
+                                        .getCurrentExecutionAttempt()
+                                        .markFinished();
+                            });
+        }
+
+        return graph;
+    }
+
+    /** Creates a calculator over the whole graph that allows checkpoints after tasks finished. */
+    private DefaultCheckpointPlanCalculator createCheckpointPlanCalculator(ExecutionGraph graph) {
+        DefaultCheckpointPlanCalculator checkpointPlanCalculator =
+                new DefaultCheckpointPlanCalculator(
+                        graph.getJobID(),
+                        new ExecutionGraphCheckpointPlanCalculatorContext(graph),
+                        graph.getVerticesTopologically());
+        checkpointPlanCalculator.setAllowCheckpointsAfterTasksFinished(true);
+        return checkpointPlanCalculator;
+    }
+
+    /**
+     * Asserts every part of the computed plan. Running tasks are expected both as the tasks to
+     * commit to and, through their current executions, as the tasks to wait for.
+     */
+    private void checkCheckpointPlan(
+            List<ExecutionVertex> expectedToTrigger,
+            List<ExecutionVertex> expectedRunning,
+            List<Execution> expectedFinished,
+            List<ExecutionJobVertex> expectedFullyFinished,
+            CheckpointPlan plan) {
+
+        // Compares tasks to trigger
+        List<Execution> expectedTriggeredExecutions =
+                expectedToTrigger.stream()
+                        .map(ExecutionVertex::getCurrentExecutionAttempt)
+                        .collect(Collectors.toList());
+        assertSameInstancesWithoutOrder(
+                "The computed tasks to trigger is different from expected",
+                expectedTriggeredExecutions,
+                plan.getTasksToTrigger());
+
+        // Compares running tasks, which are the tasks to commit to
+        assertSameInstancesWithoutOrder(
+                "The computed running tasks is different from expected",
+                expectedRunning,
+                plan.getTasksToCommitTo());
+
+        // Compares finished tasks
+        assertSameInstancesWithoutOrder(
+                "The computed finished tasks is different from expected",
+                expectedFinished,
+                plan.getFinishedTasks());
+
+        // Compares fully finished job vertices
+        assertSameInstancesWithoutOrder(
+                "The computed fully finished JobVertex is different from expected",
+                expectedFullyFinished,
+                plan.getFullyFinishedJobVertex());
+
+        // Compares tasks to ack
+        assertSameInstancesWithoutOrder(
+                "The computed tasks to ack is different from expected",
+                expectedRunning.stream()
+                        .map(ExecutionVertex::getCurrentExecutionAttempt)
+                        .collect(Collectors.toList()),
+                plan.getTasksToWaitFor());
+    }
+
+    /**
+     * Asserts that {@code actual} holds exactly the instances in {@code expected}, compared by
+     * identity and in any order.
+     */
+    private <T> void assertSameInstancesWithoutOrder(
+            String comment, Collection<T> expected, Collection<T> actual) {
+        // The computed collection is the value under test, so that Hamcrest reports the expected
+        // and actual instances the right way round on failure.
+        assertThat(
+                comment,
+                actual,
+                containsInAnyOrder(
+                        expected.stream()
+                                .map(CoreMatchers::sameInstance)
+                                .collect(Collectors.toList())));
+    }
+
+    /** Returns the execution vertices of all subtasks selected by the given declarations. */
+    private List<ExecutionVertex> chooseTasks(
+            ExecutionGraph graph, TaskDeclaration... chosenDeclarations) {
+        List<ExecutionVertex> tasks = new ArrayList<>();
+
+        for (TaskDeclaration chosenDeclaration : chosenDeclarations) {
+            ExecutionJobVertex jobVertex = chooseJobVertex(graph, chosenDeclaration.vertexIndex);
+            chosenDeclaration.subtaskIndices.forEach(
+                    index -> tasks.add(jobVertex.getTaskVertices()[index]));
+        }
+
+        return tasks;
+    }
+
+    /** Looks up the job vertex that was created for the declaration at {@code vertexIndex}. */
+    private ExecutionJobVertex chooseJobVertex(ExecutionGraph graph, int vertexIndex) {
+        String name = vertexName(vertexIndex);
+        Optional<ExecutionJobVertex> foundVertex =
+                graph.getAllVertices().values().stream()
+                        .filter(jobVertex -> jobVertex.getName().equals(name))
+                        .findFirst();
+
+        return foundVertex.orElseThrow(
+                () -> new RuntimeException("Vertex not found with index " + vertexIndex));
+    }
+
+    /** Returns the name given to the job vertex of the declaration at {@code index}. */
+    private String vertexName(int index) {
+        return "vertex_" + index;
+    }
+
+    /** Returns the integers in {@code [start, end)}. */
+    private Set<Integer> range(int start, int end) {
+        return IntStream.range(start, end).boxed().collect(Collectors.toSet());
+    }
+
+    /** Returns a set containing the given indices. */
+    private Set<Integer> of(Integer... index) {
+        return new HashSet<>(Arrays.asList(index));
+    }
+
+    /** Returns the elements of {@code all} that are not contained in {@code toMinus}. */
+    private Set<Integer> minus(Set<Integer> all, Set<Integer> toMinus) {
+        return all.stream().filter(e -> !toMinus.contains(e)).collect(Collectors.toSet());
+    }
+
+    // ------------------------- Utility helper classes ---------------------------------------
+
+    /** Declares a vertex by its parallelism and the indices of its finished subtasks. */
+    private static class VertexDeclaration {
+        final int parallelism;
+        final Set<Integer> finishedSubtaskIndices;
+
+        public VertexDeclaration(int parallelism, Set<Integer> finishedSubtaskIndices) {
+            this.parallelism = parallelism;
+            this.finishedSubtaskIndices = finishedSubtaskIndices;
+        }
+    }
+
+    /** Declares an edge from the vertex at index {@code source} to the one at {@code target}. */
+    private static class EdgeDeclaration {
+        final int source;
+        final int target;
+        final DistributionPattern distributionPattern;
+
+        public EdgeDeclaration(int source, int target, DistributionPattern distributionPattern) {
+            this.source = source;
+            this.target = target;
+            this.distributionPattern = distributionPattern;
+        }
+    }
+
+    /** Declares a set of subtasks of the vertex at index {@code vertexIndex}. */
+    private static class TaskDeclaration {
+        final int vertexIndex;
+
+        final Set<Integer> subtaskIndices;
+
+        public TaskDeclaration(int vertexIndex, Set<Integer> subtaskIndices) {
+            this.vertexIndex = vertexIndex;
+            this.subtaskIndices = subtaskIndices;
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
index 10fad28..9396ef0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
@@ -18,13 +18,12 @@
 
 package org.apache.flink.runtime.checkpoint;
 
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.executiongraph.Execution;
-import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
-import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionGraphCheckpointPlanCalculatorContext;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
@@ -32,7 +31,6 @@ import org.apache.flink.util.TestLogger;
 
 import org.junit.Before;
 import org.junit.Test;
-import org.mockito.Mockito;
 
 import java.util.Collections;
 import java.util.concurrent.ThreadLocalRandom;
@@ -40,7 +38,6 @@ import java.util.concurrent.ThreadLocalRandom;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
-import static org.powermock.api.mockito.PowerMockito.when;
 
 /** Tests for actions of {@link CheckpointCoordinator} on task failures. */
 public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
@@ -56,10 +53,13 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
      * on job failover could handle the {@code currentPeriodicTrigger} null case well.
      */
     @Test
-    public void testAbortPendingCheckpointsWithTriggerValidation() {
+    public void testAbortPendingCheckpointsWithTriggerValidation() throws Exception {
         final int maxConcurrentCheckpoints = ThreadLocalRandom.current().nextInt(10) + 1;
-        ExecutionVertex executionVertex = mockExecutionVertex();
-        JobID jobId = new JobID();
+        ExecutionGraph graph =
+                new CheckpointCoordinatorTestingUtils.CheckpointExecutionGraphBuilder()
+                        .addJobVertex(new JobVertexID())
+                        .setTransitToRunning(false)
+                        .build();
         CheckpointCoordinatorConfiguration checkpointCoordinatorConfiguration =
                 new CheckpointCoordinatorConfiguration(
                         Integer.MAX_VALUE,
@@ -73,7 +73,7 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
                         0);
         CheckpointCoordinator checkpointCoordinator =
                 new CheckpointCoordinator(
-                        jobId,
+                        graph.getJobID(),
                         checkpointCoordinatorConfiguration,
                         Collections.emptyList(),
                         new StandaloneCheckpointIDCounter(),
@@ -84,16 +84,19 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
                         manualThreadExecutor,
                         SharedStateRegistry.DEFAULT_FACTORY,
                         mock(CheckpointFailureManager.class),
-                        new CheckpointPlanCalculator(
-                                jobId,
-                                Collections.singletonList(executionVertex),
-                                Collections.singletonList(executionVertex),
-                                Collections.singletonList(executionVertex)),
-                        new ExecutionAttemptMappingProvider(
-                                Collections.singletonList(executionVertex)));
+                        new DefaultCheckpointPlanCalculator(
+                                graph.getJobID(),
+                                new ExecutionGraphCheckpointPlanCalculatorContext(graph),
+                                graph.getVerticesTopologically()),
+                        new ExecutionAttemptMappingProvider(graph.getAllExecutionVertices()));
 
         // switch current execution's state to running to allow checkpoint could be triggered.
-        mockExecutionRunning(executionVertex);
+        graph.transitionToRunning();
+        graph.getAllExecutionVertices()
+                .forEach(
+                        task ->
+                                task.getCurrentExecutionAttempt()
+                                        .transitionState(ExecutionState.RUNNING));
 
         checkpointCoordinator.startCheckpointScheduler();
         assertTrue(checkpointCoordinator.isCurrentPeriodicTriggerAvailable());
@@ -124,18 +127,4 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
         assertTrue(checkpointCoordinator.isCurrentPeriodicTriggerAvailable());
         assertEquals(0, checkpointCoordinator.getNumberOfPendingCheckpoints());
     }
-
-    private ExecutionVertex mockExecutionVertex() {
-        ExecutionAttemptID executionAttemptID = new ExecutionAttemptID();
-        ExecutionVertex executionVertex = mock(ExecutionVertex.class);
-        Execution execution = Mockito.mock(Execution.class);
-        when(execution.getAttemptId()).thenReturn(executionAttemptID);
-        when(executionVertex.getCurrentExecutionAttempt()).thenReturn(execution);
-        return executionVertex;
-    }
-
-    private void mockExecutionRunning(ExecutionVertex executionVertex) {
-        when(executionVertex.getCurrentExecutionAttempt().getState())
-                .thenReturn(ExecutionState.RUNNING);
-    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index 56c8aca..9069333 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.Str
 import org.apache.flink.runtime.checkpoint.PendingCheckpoint.TaskAcknowledgeResult;
 import org.apache.flink.runtime.checkpoint.hooks.MasterHooks;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
@@ -56,7 +57,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
@@ -82,7 +82,7 @@ import static org.powermock.api.mockito.PowerMockito.when;
 /** Tests for the {@link PendingCheckpoint}. */
 public class PendingCheckpointTest {
 
-    private static final Map<ExecutionAttemptID, ExecutionVertex> ACK_TASKS = new HashMap<>();
+    private static final List<Execution> ACK_TASKS = new ArrayList<>();
     private static final List<ExecutionVertex> TASKS_TO_COMMIT = new ArrayList<>();
     private static final ExecutionAttemptID ATTEMPT_ID = new ExecutionAttemptID();
 
@@ -97,7 +97,11 @@ public class PendingCheckpointTest {
         when(vertex.getMaxParallelism()).thenReturn(128);
         when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(1);
         when(vertex.getJobVertex()).thenReturn(jobVertex);
-        ACK_TASKS.put(ATTEMPT_ID, vertex);
+
+        Execution execution = mock(Execution.class);
+        when(execution.getAttemptId()).thenReturn(ATTEMPT_ID);
+        when(execution.getVertex()).thenReturn(vertex);
+        ACK_TASKS.add(execution);
         TASKS_TO_COMMIT.add(vertex);
     }
 
@@ -603,14 +607,19 @@ public class PendingCheckpointTest {
                         1024,
                         4096);
 
-        final Map<ExecutionAttemptID, ExecutionVertex> ackTasks = new HashMap<>(ACK_TASKS);
+        final List<Execution> ackTasks = new ArrayList<>(ACK_TASKS);
         final List<ExecutionVertex> tasksToCommit = new ArrayList<>(TASKS_TO_COMMIT);
 
         return new PendingCheckpoint(
                 new JobID(),
                 0,
                 1,
-                new CheckpointPlan(Collections.emptyList(), ackTasks, tasksToCommit),
+                new CheckpointPlan(
+                        Collections.emptyList(),
+                        ackTasks,
+                        tasksToCommit,
+                        Collections.emptyList(),
+                        Collections.emptyList()),
                 operatorCoordinators,
                 masterStateIdentifiers,
                 props,