You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2019/10/28 08:22:30 UTC

[flink] branch master updated (8543334 -> beb3fb0)

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 8543334  [hotfix][test] Clean up the test code in SourceStreamTaskTest and OneInputStreamTaskTest
     new 126cff6  [hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle
     new 22c3248  [hotfix] Split too large file CheckpointCoordinatorTest.java into several small files
     new 3c84c05  [hotfix] Correct code style of CheckpointCoordinator
     new 5ab6261  [FLINK-13904][checkpointing] Make trigger thread of CheckpointCoordinator single-threaded
     new fc19673  [FLINK-13904][tests] Support checkpoint consumer of SimpleAckingTaskManagerGateway
     new 41cda38  [FLINK-13904][checkpointing] Avoid competition between savepoint and periodic checkpoint triggering
     new 00fe29b  [FLINK-13904][checkpointing] Remove trigger lock of CheckpointCoordinator
     new beb3fb0  [FLINK-13904][checkpointing] Encapsule and optimize the time relevant operation of CheckpointCoordinator

The 8 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../runtime/checkpoint/CheckpointCoordinator.java  |  360 ++--
 .../runtime/executiongraph/ExecutionGraph.java     |   19 +
 .../CheckpointCoordinatorFailureTest.java          |   12 +-
 .../CheckpointCoordinatorMasterHooksTest.java      |   40 +-
 .../CheckpointCoordinatorRestoringTest.java        | 1014 ++++++++++
 .../checkpoint/CheckpointCoordinatorTest.java      | 2065 +++-----------------
 .../CheckpointCoordinatorTestingUtils.java         |  569 ++++++
 .../CheckpointCoordinatorTriggeringTest.java       |  313 +++
 .../checkpoint/CheckpointStateRestoreTest.java     |   12 +-
 .../FailoverStrategyCheckpointCoordinatorTest.java |   62 +-
 .../ManuallyTriggeredScheduledExecutor.java        |  123 +-
 .../utils/SimpleAckingTaskManagerGateway.java      |   37 +-
 .../runtime/messages/CheckpointMessagesTest.java   |   13 +-
 .../runtime/util/TestingScheduledExecutor.java     |   62 +
 14 files changed, 2645 insertions(+), 2056 deletions(-)
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java


[flink] 07/08: [FLINK-13904][checkpointing] Remove trigger lock of CheckpointCoordinator

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 00fe29b1be59c9ee6d7d6f3f71a4759f5132e12d
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Mon Sep 30 11:17:39 2019 +0800

    [FLINK-13904][checkpointing] Remove trigger lock of CheckpointCoordinator
    
    Now the checkpoint and savepoint triggering are executed in one thread without any competition.
    So the triggerLock is no longer needed.
---
 .../runtime/checkpoint/CheckpointCoordinator.java  | 228 ++++++++++-----------
 1 file changed, 107 insertions(+), 121 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index b4843e4..7517a38 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -87,13 +87,6 @@ public class CheckpointCoordinator {
 	/** Coordinator-wide lock to safeguard the checkpoint updates. */
 	private final Object lock = new Object();
 
-	/** Lock specially to make sure that trigger requests do not overtake each other.
-	 * This is not done with the coordinator-wide lock, because as part of triggering,
-	 * blocking operations may happen (distributed atomic counters).
-	 * Using a dedicated lock, we avoid blocking the processing of 'acknowledge/decline'
-	 * messages during that phase. */
-	private final Object triggerLock = new Object();
-
 	/** The job whose checkpoint this coordinator coordinates. */
 	private final JobID job;
 
@@ -537,149 +530,142 @@ public class CheckpointCoordinator {
 
 		// we will actually trigger this checkpoint!
 
-		// we lock with a special lock to make sure that trigger requests do not overtake each other.
-		// this is not done with the coordinator-wide lock, because the 'checkpointIdCounter'
-		// may issue blocking operations. Using a different lock than the coordinator-wide lock,
-		// we avoid blocking the processing of 'acknowledge/decline' messages during that time.
-		synchronized (triggerLock) {
-
-			final CheckpointStorageLocation checkpointStorageLocation;
-			final long checkpointID;
+		final CheckpointStorageLocation checkpointStorageLocation;
+		final long checkpointID;
 
-			try {
-				// this must happen outside the coordinator-wide lock, because it communicates
-				// with external services (in HA mode) and may block for a while.
-				checkpointID = checkpointIdCounter.getAndIncrement();
+		try {
+			// this must happen outside the coordinator-wide lock, because it communicates
+			// with external services (in HA mode) and may block for a while.
+			checkpointID = checkpointIdCounter.getAndIncrement();
 
-				checkpointStorageLocation = props.isSavepoint() ?
-						checkpointStorage.initializeLocationForSavepoint(checkpointID, externalSavepointLocation) :
-						checkpointStorage.initializeLocationForCheckpoint(checkpointID);
-			}
-			catch (Throwable t) {
-				int numUnsuccessful = numUnsuccessfulCheckpointsTriggers.incrementAndGet();
-				LOG.warn("Failed to trigger checkpoint for job {} ({} consecutive failed attempts so far).",
-						job,
-						numUnsuccessful,
-						t);
-				throw new CheckpointException(CheckpointFailureReason.EXCEPTION, t);
-			}
+			checkpointStorageLocation = props.isSavepoint() ?
+					checkpointStorage.initializeLocationForSavepoint(checkpointID, externalSavepointLocation) :
+					checkpointStorage.initializeLocationForCheckpoint(checkpointID);
+		}
+		catch (Throwable t) {
+			int numUnsuccessful = numUnsuccessfulCheckpointsTriggers.incrementAndGet();
+			LOG.warn("Failed to trigger checkpoint for job {} ({} consecutive failed attempts so far).",
+					job,
+					numUnsuccessful,
+					t);
+			throw new CheckpointException(CheckpointFailureReason.EXCEPTION, t);
+		}
 
-			final PendingCheckpoint checkpoint = new PendingCheckpoint(
-				job,
+		final PendingCheckpoint checkpoint = new PendingCheckpoint(
+			job,
+			checkpointID,
+			timestamp,
+			ackTasks,
+			props,
+			checkpointStorageLocation,
+			executor);
+
+		if (statsTracker != null) {
+			PendingCheckpointStats callback = statsTracker.reportPendingCheckpoint(
 				checkpointID,
 				timestamp,
-				ackTasks,
-				props,
-				checkpointStorageLocation,
-				executor);
-
-			if (statsTracker != null) {
-				PendingCheckpointStats callback = statsTracker.reportPendingCheckpoint(
-					checkpointID,
-					timestamp,
-					props);
+				props);
 
-				checkpoint.setStatsCallback(callback);
-			}
+			checkpoint.setStatsCallback(callback);
+		}
 
-			// schedule the timer that will clean up the expired checkpoints
-			final Runnable canceller = () -> {
-				synchronized (lock) {
-					// only do the work if the checkpoint is not discarded anyways
-					// note that checkpoint completion discards the pending checkpoint object
-					if (!checkpoint.isDiscarded()) {
-						LOG.info("Checkpoint {} of job {} expired before completing.", checkpointID, job);
+		// schedule the timer that will clean up the expired checkpoints
+		final Runnable canceller = () -> {
+			synchronized (lock) {
+				// only do the work if the checkpoint is not discarded anyways
+				// note that checkpoint completion discards the pending checkpoint object
+				if (!checkpoint.isDiscarded()) {
+					LOG.info("Checkpoint {} of job {} expired before completing.", checkpointID, job);
 
-						failPendingCheckpoint(checkpoint, CheckpointFailureReason.CHECKPOINT_EXPIRED);
-						pendingCheckpoints.remove(checkpointID);
-						rememberRecentCheckpointId(checkpointID);
+					failPendingCheckpoint(checkpoint, CheckpointFailureReason.CHECKPOINT_EXPIRED);
+					pendingCheckpoints.remove(checkpointID);
+					rememberRecentCheckpointId(checkpointID);
 
-						triggerQueuedRequests();
-					}
+					triggerQueuedRequests();
 				}
-			};
+			}
+		};
 
-			try {
-				// re-acquire the coordinator-wide lock
-				synchronized (lock) {
-					// since we released the lock in the meantime, we need to re-check
-					// that the conditions still hold.
-					if (shutdown) {
-						throw new CheckpointException(CheckpointFailureReason.CHECKPOINT_COORDINATOR_SHUTDOWN);
+		try {
+			// re-acquire the coordinator-wide lock
+			synchronized (lock) {
+				// since we released the lock in the meantime, we need to re-check
+				// that the conditions still hold.
+				if (shutdown) {
+					throw new CheckpointException(CheckpointFailureReason.CHECKPOINT_COORDINATOR_SHUTDOWN);
+				}
+				else if (!props.forceCheckpoint()) {
+					if (triggerRequestQueued) {
+						LOG.warn("Trying to trigger another checkpoint for job {} while one was queued already.", job);
+						throw new CheckpointException(CheckpointFailureReason.ALREADY_QUEUED);
 					}
-					else if (!props.forceCheckpoint()) {
-						if (triggerRequestQueued) {
-							LOG.warn("Trying to trigger another checkpoint for job {} while one was queued already.", job);
-							throw new CheckpointException(CheckpointFailureReason.ALREADY_QUEUED);
-						}
 
-						checkConcurrentCheckpoints();
+					checkConcurrentCheckpoints();
 
-						checkMinPauseBetweenCheckpoints();
-					}
+					checkMinPauseBetweenCheckpoints();
+				}
 
-					LOG.info("Triggering checkpoint {} @ {} for job {}.", checkpointID, timestamp, job);
+				LOG.info("Triggering checkpoint {} @ {} for job {}.", checkpointID, timestamp, job);
 
-					pendingCheckpoints.put(checkpointID, checkpoint);
+				pendingCheckpoints.put(checkpointID, checkpoint);
 
-					ScheduledFuture<?> cancellerHandle = timer.schedule(
-							canceller,
-							checkpointTimeout, TimeUnit.MILLISECONDS);
+				ScheduledFuture<?> cancellerHandle = timer.schedule(
+						canceller,
+						checkpointTimeout, TimeUnit.MILLISECONDS);
 
-					if (!checkpoint.setCancellerHandle(cancellerHandle)) {
-						// checkpoint is already disposed!
-						cancellerHandle.cancel(false);
-					}
+				if (!checkpoint.setCancellerHandle(cancellerHandle)) {
+					// checkpoint is already disposed!
+					cancellerHandle.cancel(false);
+				}
 
-					// trigger the master hooks for the checkpoint
-					final List<MasterState> masterStates = MasterHooks.triggerMasterHooks(masterHooks.values(),
-							checkpointID, timestamp, executor, Time.milliseconds(checkpointTimeout));
-					for (MasterState s : masterStates) {
-						checkpoint.addMasterState(s);
-					}
+				// trigger the master hooks for the checkpoint
+				final List<MasterState> masterStates = MasterHooks.triggerMasterHooks(masterHooks.values(),
+						checkpointID, timestamp, executor, Time.milliseconds(checkpointTimeout));
+				for (MasterState s : masterStates) {
+					checkpoint.addMasterState(s);
 				}
-				// end of lock scope
+			}
+			// end of lock scope
 
-				final CheckpointOptions checkpointOptions = new CheckpointOptions(
-						props.getCheckpointType(),
-						checkpointStorageLocation.getLocationReference());
+			final CheckpointOptions checkpointOptions = new CheckpointOptions(
+					props.getCheckpointType(),
+					checkpointStorageLocation.getLocationReference());
 
-				// send the messages to the tasks that trigger their checkpoint
-				for (Execution execution: executions) {
-					if (props.isSynchronous()) {
-						execution.triggerSynchronousSavepoint(checkpointID, timestamp, checkpointOptions, advanceToEndOfTime);
-					} else {
-						execution.triggerCheckpoint(checkpointID, timestamp, checkpointOptions);
-					}
+			// send the messages to the tasks that trigger their checkpoint
+			for (Execution execution: executions) {
+				if (props.isSynchronous()) {
+					execution.triggerSynchronousSavepoint(checkpointID, timestamp, checkpointOptions, advanceToEndOfTime);
+				} else {
+					execution.triggerCheckpoint(checkpointID, timestamp, checkpointOptions);
 				}
-
-				numUnsuccessfulCheckpointsTriggers.set(0);
-				return checkpoint.getCompletionFuture();
 			}
-			catch (Throwable t) {
-				// guard the map against concurrent modifications
-				synchronized (lock) {
-					pendingCheckpoints.remove(checkpointID);
-				}
 
-				int numUnsuccessful = numUnsuccessfulCheckpointsTriggers.incrementAndGet();
-				LOG.warn("Failed to trigger checkpoint {} for job {}. ({} consecutive failed attempts so far)",
-						checkpointID, job, numUnsuccessful, t);
+			numUnsuccessfulCheckpointsTriggers.set(0);
+			return checkpoint.getCompletionFuture();
+		}
+		catch (Throwable t) {
+			// guard the map against concurrent modifications
+			synchronized (lock) {
+				pendingCheckpoints.remove(checkpointID);
+			}
 
-				if (!checkpoint.isDiscarded()) {
-					failPendingCheckpoint(checkpoint, CheckpointFailureReason.TRIGGER_CHECKPOINT_FAILURE, t);
-				}
+			int numUnsuccessful = numUnsuccessfulCheckpointsTriggers.incrementAndGet();
+			LOG.warn("Failed to trigger checkpoint {} for job {}. ({} consecutive failed attempts so far)",
+					checkpointID, job, numUnsuccessful, t);
 
-				try {
-					checkpointStorageLocation.disposeOnFailure();
-				}
-				catch (Throwable t2) {
-					LOG.warn("Cannot dispose failed checkpoint storage location {}", checkpointStorageLocation, t2);
-				}
+			if (!checkpoint.isDiscarded()) {
+				failPendingCheckpoint(checkpoint, CheckpointFailureReason.TRIGGER_CHECKPOINT_FAILURE, t);
+			}
 
-				throw new CheckpointException(CheckpointFailureReason.EXCEPTION, t);
+			try {
+				checkpointStorageLocation.disposeOnFailure();
 			}
-		} // end trigger lock
+			catch (Throwable t2) {
+				LOG.warn("Cannot dispose failed checkpoint storage location {}", checkpointStorageLocation, t2);
+			}
+
+			throw new CheckpointException(CheckpointFailureReason.EXCEPTION, t);
+		}
 	}
 
 	// --------------------------------------------------------------------------------------------


[flink] 06/08: [FLINK-13904][checkpointing] Avoid competition between savepoint and periodic checkpoint triggering

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 41cda38e17c1884c72b9bb092e6447ff40f2b5bd
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Wed Sep 25 21:00:52 2019 +0800

    [FLINK-13904][checkpointing] Avoid competition between savepoint and periodic checkpoint triggering
---
 .../runtime/checkpoint/CheckpointCoordinator.java  |  46 ++--
 .../CheckpointCoordinatorFailureTest.java          |  14 +-
 .../CheckpointCoordinatorMasterHooksTest.java      |  43 +++-
 .../CheckpointCoordinatorRestoringTest.java        |  36 ++-
 .../checkpoint/CheckpointCoordinatorTest.java      | 285 +++++++++++----------
 .../CheckpointCoordinatorTestingUtils.java         | 224 ++++++++--------
 .../CheckpointCoordinatorTriggeringTest.java       |  62 +++--
 .../checkpoint/CheckpointStateRestoreTest.java     |  15 +-
 .../FailoverStrategyCheckpointCoordinatorTest.java |   5 +-
 .../ManuallyTriggeredScheduledExecutor.java        | 123 ++++++---
 .../utils/SimpleAckingTaskManagerGateway.java      |  34 ++-
 .../runtime/util/TestingScheduledExecutor.java     |   8 +-
 12 files changed, 509 insertions(+), 386 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index df9278e..b4843e4 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -408,46 +408,57 @@ public class CheckpointCoordinator {
 
 		checkNotNull(checkpointProperties);
 
-		try {
-			PendingCheckpoint pendingCheckpoint = triggerCheckpoint(
+		// TODO, call triggerCheckpoint directly after removing timer thread
+		// for now, execute the trigger in timer thread to avoid competition
+		final CompletableFuture<CompletedCheckpoint> resultFuture = new CompletableFuture<>();
+		timer.execute(() -> {
+			try {
+				triggerCheckpoint(
 					timestamp,
 					checkpointProperties,
 					targetLocation,
 					false,
-					advanceToEndOfEventTime);
-
-			return pendingCheckpoint.getCompletionFuture();
-		} catch (CheckpointException e) {
-			Throwable cause = new CheckpointException("Failed to trigger savepoint.", e.getCheckpointFailureReason());
-			return FutureUtils.completedExceptionally(cause);
-		}
+					advanceToEndOfEventTime).
+				whenComplete((completedCheckpoint, throwable) -> {
+					if (throwable == null) {
+						resultFuture.complete(completedCheckpoint);
+					} else {
+						resultFuture.completeExceptionally(throwable);
+					}
+				});
+			} catch (CheckpointException e) {
+				Throwable cause = new CheckpointException("Failed to trigger savepoint.", e.getCheckpointFailureReason());
+				resultFuture.completeExceptionally(cause);
+			}
+		});
+		return resultFuture;
 	}
 
 	/**
 	 * Triggers a new standard checkpoint and uses the given timestamp as the checkpoint
-	 * timestamp.
+	 * timestamp. The return value is a future. It completes when the checkpoint triggered finishes
+	 * or an error occurred.
 	 *
 	 * @param timestamp The timestamp for the checkpoint.
 	 * @param isPeriodic Flag indicating whether this triggered checkpoint is
 	 * periodic. If this flag is true, but the periodic scheduler is disabled,
 	 * the checkpoint will be declined.
-	 * @return <code>true</code> if triggering the checkpoint succeeded.
+	 * @return a future to the completed checkpoint.
 	 */
-	public boolean triggerCheckpoint(long timestamp, boolean isPeriodic) {
+	public CompletableFuture<CompletedCheckpoint> triggerCheckpoint(long timestamp, boolean isPeriodic) {
 		try {
-			triggerCheckpoint(timestamp, checkpointProperties, null, isPeriodic, false);
-			return true;
+			return triggerCheckpoint(timestamp, checkpointProperties, null, isPeriodic, false);
 		} catch (CheckpointException e) {
 			long latestGeneratedCheckpointId = getCheckpointIdCounter().get();
 			// here we can not get the failed pending checkpoint's id,
 			// so we pass the negative latest generated checkpoint id as a special flag
 			failureManager.handleJobLevelCheckpointException(e, -1 * latestGeneratedCheckpointId);
-			return false;
+			return FutureUtils.completedExceptionally(e);
 		}
 	}
 
 	@VisibleForTesting
-	public PendingCheckpoint triggerCheckpoint(
+	public CompletableFuture<CompletedCheckpoint> triggerCheckpoint(
 			long timestamp,
 			CheckpointProperties props,
 			@Nullable String externalSavepointLocation,
@@ -643,7 +654,7 @@ public class CheckpointCoordinator {
 				}
 
 				numUnsuccessfulCheckpointsTriggers.set(0);
-				return checkpoint;
+				return checkpoint.getCompletionFuture();
 			}
 			catch (Throwable t) {
 				// guard the map against concurrent modifications
@@ -668,7 +679,6 @@ public class CheckpointCoordinator {
 
 				throw new CheckpointException(CheckpointFailureReason.EXCEPTION, t);
 			}
-
 		} // end trigger lock
 	}
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 39a8c2e..92f42bc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
@@ -31,10 +32,8 @@ import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.OperatorStreamStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.TestLogger;
 
-import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -58,10 +57,6 @@ import static org.mockito.Mockito.when;
 @PrepareForTest(PendingCheckpoint.class)
 public class CheckpointCoordinatorFailureTest extends TestLogger {
 
-	@Rule
-	public final TestingScheduledExecutor testingScheduledExecutor =
-		new TestingScheduledExecutor();
-
 	/**
 	 * Tests that a failure while storing a completed checkpoint in the completed checkpoint store
 	 * will properly fail the originating pending checkpoint and clean upt the completed checkpoint.
@@ -70,6 +65,9 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 	public void testFailingCompletedCheckpointStoreAdd() throws Exception {
 		JobID jid = new JobID();
 
+		final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
+			new ManuallyTriggeredScheduledExecutor();
+
 		final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID();
 		final ExecutionVertex vertex = CheckpointCoordinatorTestingUtils.mockExecutionVertex(executionAttemptId);
 
@@ -99,12 +97,14 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 			new FailingCompletedCheckpointStore(),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
 		coord.triggerCheckpoint(triggerTimestamp, false);
 
+		manuallyTriggeredScheduledExecutor.triggerAll();
+
 		assertEquals(1, coord.getNumberOfPendingCheckpoints());
 
 		PendingCheckpoint pendingCheckpoint = coord.getPendingCheckpoints().values().iterator().next();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index d85d014..7c9a357 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -21,6 +21,8 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
@@ -32,9 +34,7 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
-import org.apache.flink.runtime.util.TestingScheduledExecutor;
 
-import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -71,10 +71,6 @@ import static org.mockito.Mockito.when;
  */
 public class CheckpointCoordinatorMasterHooksTest {
 
-	@Rule
-	public final TestingScheduledExecutor testingScheduledExecutor =
-		new TestingScheduledExecutor();
-
 	// ------------------------------------------------------------------------
 	//  hook registration
 	// ------------------------------------------------------------------------
@@ -194,14 +190,20 @@ public class CheckpointCoordinatorMasterHooksTest {
 		final JobID jid = new JobID();
 		final ExecutionAttemptID execId = new ExecutionAttemptID();
 		final ExecutionVertex ackVertex = mockExecutionVertex(execId);
-		final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
+		final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
+			new ManuallyTriggeredScheduledExecutor();
+		final CheckpointCoordinator cc = instantiateCheckpointCoordinator(
+			jid, manuallyTriggeredScheduledExecutor, ackVertex);
 
 		cc.addMasterHook(statefulHook1);
 		cc.addMasterHook(statelessHook);
 		cc.addMasterHook(statefulHook2);
 
 		// trigger a checkpoint
-		assertTrue(cc.triggerCheckpoint(System.currentTimeMillis(), false));
+		final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+			cc.triggerCheckpoint(System.currentTimeMillis(), false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
+		assertFalse(checkpointFuture.isCompletedExceptionally());
 		assertEquals(1, cc.getNumberOfPendingCheckpoints());
 
 		verify(statefulHook1, times(1)).triggerCheckpoint(anyLong(), anyLong(), any(Executor.class));
@@ -370,7 +372,10 @@ public class CheckpointCoordinatorMasterHooksTest {
 		final JobID jid = new JobID();
 		final ExecutionAttemptID execId = new ExecutionAttemptID();
 		final ExecutionVertex ackVertex = mockExecutionVertex(execId);
-		final CheckpointCoordinator cc = instantiateCheckpointCoordinator(jid, ackVertex);
+		final ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor =
+			new ManuallyTriggeredScheduledExecutor();
+		final CheckpointCoordinator cc = instantiateCheckpointCoordinator(
+			jid, manuallyTriggeredScheduledExecutor, ackVertex);
 
 		final MasterTriggerRestoreHook<Void> hook = mockGeneric(MasterTriggerRestoreHook.class);
 		when(hook.getIdentifier()).thenReturn(id);
@@ -391,7 +396,10 @@ public class CheckpointCoordinatorMasterHooksTest {
 		cc.addMasterHook(hook);
 
 		// trigger a checkpoint
-		assertTrue(cc.triggerCheckpoint(System.currentTimeMillis(), false));
+		final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+			cc.triggerCheckpoint(System.currentTimeMillis(), false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
+		assertFalse(checkpointFuture.isCompletedExceptionally());
 	}
 
 
@@ -427,7 +435,18 @@ public class CheckpointCoordinatorMasterHooksTest {
 	//  utilities
 	// ------------------------------------------------------------------------
 
-	private CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, ExecutionVertex... ackVertices) {
+	private CheckpointCoordinator instantiateCheckpointCoordinator(
+		JobID jid,
+		ExecutionVertex... ackVertices) {
+
+		return instantiateCheckpointCoordinator(jid, new ManuallyTriggeredScheduledExecutor(), ackVertices);
+	}
+
+	private CheckpointCoordinator instantiateCheckpointCoordinator(
+		JobID jid,
+		ScheduledExecutor testingScheduledExecutor,
+		ExecutionVertex... ackVertices) {
+
 		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
 			10000000L,
 			600000L,
@@ -447,7 +466,7 @@ public class CheckpointCoordinatorMasterHooksTest {
 				new StandaloneCompletedCheckpointStore(10),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				testingScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				new CheckpointFailureManager(
 					0,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
index 1725ef2..e0be7ee 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
@@ -39,7 +40,6 @@ import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
-import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.SerializableObject;
 import org.apache.flink.util.TestLogger;
 
@@ -76,6 +76,8 @@ import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUt
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 import static org.mockito.ArgumentMatchers.any;
@@ -90,9 +92,7 @@ import static org.mockito.Mockito.when;
 public class CheckpointCoordinatorRestoringTest extends TestLogger {
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
-	@Rule
-	public final TestingScheduledExecutor testingScheduledExecutor =
-		new TestingScheduledExecutor();
+	private ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor;
 
 	private CheckpointFailureManager failureManager;
 
@@ -104,6 +104,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 		failureManager = new CheckpointFailureManager(
 			0,
 			NoOpFailJobCall.INSTANCE);
+		manuallyTriggeredScheduledExecutor = new ManuallyTriggeredScheduledExecutor();
 	}
 
 	/**
@@ -163,14 +164,15 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			store,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		assertEquals(1, coord.getPendingCheckpoints().size());
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 
 		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
@@ -296,12 +298,15 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 				store,
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			//trigger a checkpoint and wait to become a completed checkpoint
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture.isCompletedExceptionally());
 
 			long checkpointId = checkpointIDCounter.getLast();
 
@@ -344,11 +349,12 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 					StateObjectCollection.singleton(serializedKeyGroupStatesForSavepoint),
 					StateObjectCollection.empty()));
 
+			manuallyTriggeredScheduledExecutor.triggerAll();
 			checkpointId = checkpointIDCounter.getLast();
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStatesForSavepoint), TASK_MANAGER_LOCATION_INFO);
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId), TASK_MANAGER_LOCATION_INFO);
 
-			assertTrue(savepointFuture.isDone());
+			assertNotNull(savepointFuture.get());
 
 			//restore and jump the latest savepoint
 			coord.restoreLatestCheckpointedState(map, true, false);
@@ -446,14 +452,15 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		assertEquals(1, coord.getPendingCheckpoints().size());
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 
 		List<KeyGroupRange> keyGroupPartitions1 =
@@ -624,14 +631,15 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		assertEquals(1, coord.getPendingCheckpoints().size());
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 
 		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
@@ -879,7 +887,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			standaloneCompletedCheckpointStore,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index ddd2f8e..22d59bc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -21,8 +21,8 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.Path;
-import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.TestingScheduledServiceWithRecordingScheduledTasks;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
@@ -50,7 +50,6 @@ import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
-import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.TestLogger;
 
@@ -76,6 +75,7 @@ import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
@@ -106,12 +106,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
-	@Rule
-	public final TestingScheduledExecutor testingScheduledExecutor =
-		new TestingScheduledExecutor();
-
 	private CheckpointFailureManager failureManager;
 
+	private ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor;
+
 	@Rule
 	public TemporaryFolder tmpFolder = new TemporaryFolder();
 
@@ -120,6 +118,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		failureManager = new CheckpointFailureManager(
 			0,
 			NoOpFailJobCall.INSTANCE);
+		manuallyTriggeredScheduledExecutor = new ManuallyTriggeredScheduledExecutor();
 	}
 
 	@Test
@@ -158,7 +157,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -167,7 +166,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// trigger the first checkpoint. this should not succeed
-			assertFalse(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertTrue(checkpointFuture.isCompletedExceptionally());
 
 			// still, nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -226,7 +228,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -235,7 +237,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// trigger the first checkpoint. this should not succeed
-			assertFalse(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertTrue(checkpointFuture.isCompletedExceptionally());
 
 			// still, nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -285,7 +290,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -294,7 +299,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// trigger the first checkpoint. this should not succeed
-			assertFalse(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertTrue(checkpointFuture.isCompletedExceptionally());
 
 			// still, nothing should be happening
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -309,7 +317,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	}
 
 	@Test
-	public void testTriggerAndDeclineCheckpointThenFailureManagerThrowsException() {
+	public void testTriggerAndDeclineCheckpointThenFailureManagerThrowsException() throws Exception {
 		final JobID jid = new JobID();
 		final long timestamp = System.currentTimeMillis();
 
@@ -335,15 +343,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				}
 			});
 
-		final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
-			new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
-
 		// set up the coordinator
-		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, checkpointFailureManager, scheduledExecutorService);
+		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, checkpointFailureManager, manuallyTriggeredScheduledExecutor);
 
 		try {
 			// trigger the checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkPointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkPointFuture.isCompletedExceptionally());
 
 			long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId);
@@ -389,23 +397,24 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 			ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
-			final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
-				new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 			// set up the coordinator and validate the initial state
-			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
+			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, manuallyTriggeredScheduledExecutor);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture.isCompletedExceptionally());
 
 			// validate that we have a pending checkpoint
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// we have one task scheduled that will cancel after timeout
-			assertEquals(1, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(1, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId);
@@ -442,7 +451,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertTrue(checkpoint.isDiscarded());
 
 			// the canceler is also removed
-			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(0, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			// validate that we have no new pending checkpoint
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -478,26 +487,29 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			final ExecutionAttemptID attemptID2 = new ExecutionAttemptID();
 			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 			ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
-
-			final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
-				new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 			// set up the coordinator and validate the initial state
-			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
+			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, manuallyTriggeredScheduledExecutor);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(0, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture1 =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture1.isCompletedExceptionally());
 
 			// trigger second checkpoint, should also succeed
-			assertTrue(coord.triggerCheckpoint(timestamp + 2, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture2 =
+				coord.triggerCheckpoint(timestamp + 2, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture2.isCompletedExceptionally());
 
 			// validate that we have a pending checkpoint
 			assertEquals(2, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(2, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(2, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			Iterator<Map.Entry<Long, PendingCheckpoint>> it = coord.getPendingCheckpoints().entrySet().iterator();
 			long checkpoint1Id = it.next().getKey();
@@ -544,7 +556,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// validate that we have only one pending checkpoint left
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(1, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(1, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			// validate that it is the same second checkpoint from earlier
 			long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
@@ -587,22 +599,23 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 			ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
-			final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
-				new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 			// set up the coordinator and validate the initial state
-			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
+			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, manuallyTriggeredScheduledExecutor);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(0, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture.isCompletedExceptionally());
 
 			// validate that we have a pending checkpoint
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(1, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(1, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId);
@@ -659,7 +672,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 
 			// the canceler should be removed now
-			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(0, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			// validate that the subtasks states have registered their shared states.
 			{
@@ -684,6 +697,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// ---------------
 			final long timestampNew = timestamp + 7;
 			coord.triggerCheckpoint(timestampNew, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
 
 			long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew), TASK_MANAGER_LOCATION_INFO);
@@ -691,7 +705,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
+			assertEquals(0, manuallyTriggeredScheduledExecutor.getScheduledTasks().size());
 
 			CompletedCheckpoint successNew = coord.getSuccessfulCheckpoints().get(0);
 			assertEquals(jid, successNew.getJobId());
@@ -763,7 +777,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -771,7 +785,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp1, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture1 =
+				coord.triggerCheckpoint(timestamp1, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture1.isCompletedExceptionally());
 
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -788,7 +805,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp2, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture2 =
+				coord.triggerCheckpoint(timestamp2, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture2.isCompletedExceptionally());
 
 			assertEquals(2, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -901,7 +921,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(10),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -909,7 +929,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp1, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture1 =
+				coord.triggerCheckpoint(timestamp1, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture1.isCompletedExceptionally());
 
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -941,7 +964,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
-			assertTrue(coord.triggerCheckpoint(timestamp2, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture2 =
+				coord.triggerCheckpoint(timestamp2, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture2.isCompletedExceptionally());
 
 			assertEquals(2, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1072,12 +1098,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			// trigger a checkpoint, partially acknowledged
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture.isCompletedExceptionally());
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().values().iterator().next();
@@ -1091,17 +1120,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), taskOperatorSubtaskStates1), TASK_MANAGER_LOCATION_INFO);
 
-			// wait until the checkpoint must have expired.
-			// we check every 250 msecs conservatively for 5 seconds
-			// to give even slow build servers a very good chance of completing this
-			long deadline = System.currentTimeMillis() + 5000;
-			do {
-				Thread.sleep(250);
-			}
-			while (!checkpoint.isDiscarded() &&
-					coord.getNumberOfPendingCheckpoints() > 0 &&
-					System.currentTimeMillis() < deadline);
-
+			// triggers cancelling
+			manuallyTriggeredScheduledExecutor.triggerScheduledTasks();
 			assertTrue("Checkpoint was not canceled by the timeout", checkpoint.isDiscarded());
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1157,11 +1177,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture.isCompletedExceptionally());
 
 			long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next();
 
@@ -1228,11 +1251,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
-		assertTrue(coord.triggerCheckpoint(timestamp, false));
+		final CompletableFuture<CompletedCheckpoint> checkpointFuture =
+			coord.triggerCheckpoint(timestamp, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
+		assertFalse(checkpointFuture.isCompletedExceptionally());
 
 		assertEquals(1, coord.getNumberOfPendingCheckpoints());
 
@@ -1334,10 +1360,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 		ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
-		final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
-			new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 		// set up the coordinator and validate the initial state
-		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
+		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, manuallyTriggeredScheduledExecutor);
 
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1345,6 +1369,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// trigger the first checkpoint. this should succeed
 		String savepointDir = tmpFolder.newFolder().getAbsolutePath();
 		CompletableFuture<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 		assertFalse(savepointFuture.isDone());
 
 		// validate that we have a pending savepoint
@@ -1394,7 +1419,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// the checkpoint is internally converted to a successful checkpoint and the
 		// pending checkpoint object is disposed
 		assertTrue(pending.isDiscarded());
-		assertTrue(savepointFuture.isDone());
+		assertNotNull(savepointFuture.get());
 
 		// the now we should have a completed checkpoint
 		assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1423,6 +1448,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// ---------------
 		final long timestampNew = timestamp + 7;
 		savepointFuture = coord.triggerSavepoint(timestampNew, savepointDir);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 		assertFalse(savepointFuture.isDone());
 
 		long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
@@ -1437,7 +1463,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		assertEquals(timestampNew, successNew.getTimestamp());
 		assertEquals(checkpointIdNew, successNew.getCheckpointID());
 		assertTrue(successNew.getOperatorStates().isEmpty());
-		assertTrue(savepointFuture.isDone());
+		assertNotNull(savepointFuture.get());
 
 		// validate that the first savepoint does not discard its private states.
 		verify(subtaskState1, never()).discardState();
@@ -1494,7 +1520,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(10),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1502,13 +1528,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		// Trigger savepoint and checkpoint
 		CompletableFuture<CompletedCheckpoint> savepointFuture1 = coord.triggerSavepoint(timestamp, savepointDir);
+
+		manuallyTriggeredScheduledExecutor.triggerAll();
 		long savepointId1 = counter.getLast();
 		assertEquals(1, coord.getNumberOfPendingCheckpoints());
 
-		assertTrue(coord.triggerCheckpoint(timestamp + 1, false));
+		CompletableFuture<CompletedCheckpoint> checkpointFuture1 =
+			coord.triggerCheckpoint(timestamp + 1, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 		assertEquals(2, coord.getNumberOfPendingCheckpoints());
+		assertFalse(checkpointFuture1.isCompletedExceptionally());
 
-		assertTrue(coord.triggerCheckpoint(timestamp + 2, false));
+		CompletableFuture<CompletedCheckpoint> checkpointFuture2 =
+			coord.triggerCheckpoint(timestamp + 2, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
+		assertFalse(checkpointFuture2.isCompletedExceptionally());
 		long checkpointId2 = counter.getLast();
 		assertEquals(3, coord.getNumberOfPendingCheckpoints());
 
@@ -1522,11 +1556,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		assertFalse(coord.getPendingCheckpoints().get(savepointId1).isDiscarded());
 		assertFalse(savepointFuture1.isDone());
 
-		assertTrue(coord.triggerCheckpoint(timestamp + 3, false));
+		CompletableFuture<CompletedCheckpoint> checkpointFuture3 =
+			coord.triggerCheckpoint(timestamp + 3, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
+		assertFalse(checkpointFuture3.isCompletedExceptionally());
 		assertEquals(2, coord.getNumberOfPendingCheckpoints());
 
 		CompletableFuture<CompletedCheckpoint> savepointFuture2 = coord.triggerSavepoint(timestamp + 4, savepointDir);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 		long savepointId2 = counter.getLast();
+		assertFalse(savepointFuture2.isCompletedExceptionally());
 		assertEquals(3, coord.getNumberOfPendingCheckpoints());
 
 		// 2nd savepoint should subsume the last checkpoint, but not the 1st savepoint
@@ -1538,7 +1577,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		assertFalse(coord.getPendingCheckpoints().get(savepointId1).isDiscarded());
 
 		assertFalse(savepointFuture1.isDone());
-		assertTrue(savepointFuture2.isDone());
+		assertNotNull(savepointFuture2.get());
 
 		// Ack first savepoint
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, savepointId1), TASK_MANAGER_LOCATION_INFO);
@@ -1546,7 +1585,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(3, coord.getNumberOfRetainedSuccessfulCheckpoints());
-		assertTrue(savepointFuture1.isDone());
+		assertNotNull(savepointFuture1.get());
 	}
 
 	private void testMaxConcurrentAttempts(int maxConcurrentAttempts) {
@@ -1595,22 +1634,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			coord.startCheckpointScheduler();
 
-			// after a while, there should be exactly as many checkpoints
-			// as concurrently permitted
-			long now = System.currentTimeMillis();
-			long timeout = now + 60000;
-			long minDuration = now + 100;
-			do {
-				Thread.sleep(20);
+			for (int i = 0; i < maxConcurrentAttempts; i++) {
+				manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			}
-			while ((now = System.currentTimeMillis()) < minDuration ||
-					(numCalls.get() < maxConcurrentAttempts && now < timeout));
 
 			assertEquals(maxConcurrentAttempts, numCalls.get());
 
@@ -1620,18 +1652,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// now, once we acknowledge one checkpoint, it should trigger the next one
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID, 1L), TASK_MANAGER_LOCATION_INFO);
 
-			// this should have immediately triggered a new checkpoint
-			now = System.currentTimeMillis();
-			timeout = now + 60000;
-			do {
-				Thread.sleep(20);
-			}
-			while (numCalls.get() < maxConcurrentAttempts + 1 && now < timeout);
+			final Collection<ScheduledFuture<?>> periodicScheduledTasks =
+				manuallyTriggeredScheduledExecutor.getPeriodicScheduledTask();
+			assertEquals(1, periodicScheduledTasks.size());
+			final ScheduledFuture scheduledFuture = periodicScheduledTasks.iterator().next();
+
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 
 			assertEquals(maxConcurrentAttempts + 1, numCalls.get());
 
 			// no further checkpoints should happen
-			Thread.sleep(200);
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			assertEquals(maxConcurrentAttempts + 1, numCalls.get());
 
 			coord.shutdown(JobStatus.FINISHED);
@@ -1676,22 +1707,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			coord.startCheckpointScheduler();
 
-			// after a while, there should be exactly as many checkpoints
-			// as concurrently permitted
-			long now = System.currentTimeMillis();
-			long timeout = now + 60000;
-			long minDuration = now + 100;
 			do {
-				Thread.sleep(20);
+				manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			}
-			while ((now = System.currentTimeMillis()) < minDuration ||
-					(coord.getNumberOfPendingCheckpoints() < maxConcurrentAttempts && now < timeout));
+			while (coord.getNumberOfPendingCheckpoints() < maxConcurrentAttempts);
 
 			// validate that the pending checkpoints are there
 			assertEquals(maxConcurrentAttempts, coord.getNumberOfPendingCheckpoints());
@@ -1704,12 +1729,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID, 2L), TASK_MANAGER_LOCATION_INFO);
 
 			// after a while, there should be the new checkpoints
-			final long newTimeout = System.currentTimeMillis() + 60000;
 			do {
-				Thread.sleep(20);
+				manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			}
-			while (coord.getPendingCheckpoints().get(4L) == null &&
-					System.currentTimeMillis() < newTimeout);
+			while (coord.getNumberOfPendingCheckpoints() < maxConcurrentAttempts);
 
 			// do the final check
 			assertEquals(maxConcurrentAttempts, coord.getNumberOfPendingCheckpoints());
@@ -1760,26 +1783,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			coord.startCheckpointScheduler();
 
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			// no checkpoint should have started so far
-			Thread.sleep(200);
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 
 			// now move the state to RUNNING
 			currentState.set(ExecutionState.RUNNING);
 
 			// the coordinator should start checkpointing now
-			final long timeout = System.currentTimeMillis() + 10000;
-			do {
-				Thread.sleep(20);
-			}
-			while (System.currentTimeMillis() < timeout &&
-					coord.getNumberOfPendingCheckpoints() == 0);
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 
 			assertTrue(coord.getNumberOfPendingCheckpoints() > 0);
 		}
@@ -1795,6 +1813,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	@Test
 	public void testConcurrentSavepoints() throws Exception {
 		JobID jobId = new JobID();
+		int numSavepoints = 5;
 
 		final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
 		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
@@ -1820,14 +1839,12 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(2),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
 		List<CompletableFuture<CompletedCheckpoint>> savepointFutures = new ArrayList<>();
 
-		int numSavepoints = 5;
-
 		String savepointDir = tmpFolder.newFolder().getAbsolutePath();
 
 		// Trigger savepoints
@@ -1840,6 +1857,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertFalse(savepointFuture.isDone());
 		}
 
+		manuallyTriggeredScheduledExecutor.triggerAll();
+
 		// ACK all savepoints
 		long checkpointId = checkpointIDCounter.getLast();
 		for (int i = 0; i < numSavepoints; i++, checkpointId--) {
@@ -1848,7 +1867,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		// After ACKs, all should be completed
 		for (CompletableFuture<CompletedCheckpoint> savepointFuture : savepointFutures) {
-			assertTrue(savepointFuture.isDone());
+			assertNotNull(savepointFuture.get());
 		}
 	}
 
@@ -1881,7 +1900,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(2),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1927,11 +1946,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
+			CompletableFuture<CompletedCheckpoint> checkpointFuture =
+				coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
+			assertFalse(checkpointFuture.isCompletedExceptionally());
 
 			for (PendingCheckpoint checkpoint : coord.getPendingCheckpoints().values()) {
 				CheckpointProperties props = checkpoint.getProps();
@@ -2140,7 +2162,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	 * Tests that the pending checkpoint stats callbacks are created.
 	 */
 	@Test
-	public void testCheckpointStatsTrackerPendingCheckpointCallback() {
+	public void testCheckpointStatsTrackerPendingCheckpointCallback() throws Exception {
 		final long timestamp = System.currentTimeMillis();
 		ExecutionVertex vertex1 = mockExecutionVertex(new ExecutionAttemptID());
 
@@ -2164,7 +2186,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -2175,7 +2197,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			.thenReturn(mock(PendingCheckpointStats.class));
 
 		// Trigger a checkpoint and verify callback
-		assertTrue(coord.triggerCheckpoint(timestamp, false));
+		CompletableFuture<CompletedCheckpoint> checkpointFuture =
+			coord.triggerCheckpoint(timestamp, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
+		assertFalse(checkpointFuture.isCompletedExceptionally());
 
 		verify(tracker, times(1))
 			.reportPendingCheckpoint(eq(1L), eq(timestamp), eq(CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION)));
@@ -2210,7 +2235,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			store,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -2280,7 +2305,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			store,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 				deleteExecutor -> {
 					SharedStateRegistry instance = new SharedStateRegistry(deleteExecutor);
 					createdSharedStateRegistries.add(instance);
@@ -2421,8 +2446,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		final ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 		final ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
-		final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
-			new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 		// set up the coordinator and validate the initial state
 		final CheckpointCoordinator coordinator = getCheckpointCoordinator(jobId, vertex1, vertex2,
 				new CheckpointFailureManager(
@@ -2439,11 +2462,12 @@ public class CheckpointCoordinatorTest extends TestLogger {
 							throw new AssertionError("This method should not be called for the test.");
 						}
 					}),
-			scheduledExecutorService);
+			manuallyTriggeredScheduledExecutor);
 
 		final CompletableFuture<CompletedCheckpoint> savepointFuture = coordinator
 				.triggerSynchronousSavepoint(10L, false, "test-dir");
 
+		manuallyTriggeredScheduledExecutor.triggerAll();
 		final PendingCheckpoint syncSavepoint = declineSynchronousSavepoint(jobId, coordinator, attemptID1, expectedRootCause);
 
 		assertTrue(syncSavepoint.isDiscarded());
@@ -2519,8 +2543,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		// trigger the checkpoint
 		coord.triggerCheckpoint(timestamp, false);
+		manuallyTriggeredScheduledExecutor.triggerAll();
 
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		assertEquals(1, coord.getPendingCheckpoints().size());
 		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
index 76f042b..89af153 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -18,17 +18,24 @@
 
 package org.apache.flink.runtime.checkpoint;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.runtime.concurrent.ScheduledExecutor;
+import org.apache.flink.mock.Whitebox;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway;
+import org.apache.flink.runtime.executiongraph.utils.SimpleAckingTaskManagerGateway.CheckpointConsumer;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
+import org.apache.flink.runtime.jobmaster.LogicalSlot;
+import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
 import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
@@ -42,6 +49,9 @@ import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 
 import org.junit.Assert;
+import org.mockito.invocation.InvocationOnMock;
+
+import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.io.Serializable;
@@ -50,23 +60,20 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
-import java.util.Set;
 import java.util.UUID;
-import java.util.concurrent.Callable;
-import java.util.concurrent.Delayed;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
-import java.util.concurrent.ScheduledFuture;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.when;
@@ -306,7 +313,7 @@ public class CheckpointCoordinatorTestingUtils {
 	static ExecutionJobVertex mockExecutionJobVertex(
 		JobVertexID jobVertexID,
 		int parallelism,
-		int maxParallelism) {
+		int maxParallelism) throws Exception {
 
 		return mockExecutionJobVertex(
 			jobVertexID,
@@ -320,7 +327,7 @@ public class CheckpointCoordinatorTestingUtils {
 		JobVertexID jobVertexID,
 		List<OperatorID> jobVertexIDs,
 		int parallelism,
-		int maxParallelism) {
+		int maxParallelism) throws Exception {
 		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
 
 		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
@@ -348,12 +355,39 @@ public class CheckpointCoordinatorTestingUtils {
 		return executionJobVertex;
 	}
 
-	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
+	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) throws Exception {
+		return mockExecutionVertex(attemptID, (LogicalSlot) null);
+	}
+
+	static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		CheckpointConsumer checkpointConsumer) throws Exception {
+
+		final SimpleAckingTaskManagerGateway taskManagerGateway = new SimpleAckingTaskManagerGateway();
+		taskManagerGateway.setCheckpointConsumer(checkpointConsumer);
+		return mockExecutionVertex(attemptID, taskManagerGateway);
+	}
+
+	static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		TaskManagerGateway taskManagerGateway) throws Exception {
+
+		TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
+		slotBuilder.setTaskManagerGateway(taskManagerGateway);
+		LogicalSlot	slot = slotBuilder.createTestingLogicalSlot();
+		return mockExecutionVertex(attemptID, slot);
+	}
+
+	static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		@Nullable LogicalSlot slot) throws Exception {
+
 		JobVertexID jobVertexID = new JobVertexID();
 		return mockExecutionVertex(
 			attemptID,
 			jobVertexID,
 			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			slot,
 			1,
 			1,
 			ExecutionState.RUNNING);
@@ -366,7 +400,28 @@ public class CheckpointCoordinatorTestingUtils {
 		int parallelism,
 		int maxParallelism,
 		ExecutionState state,
-		ExecutionState ... successiveStates) {
+		ExecutionState ... successiveStates) throws Exception {
+
+		return mockExecutionVertex(
+			attemptID,
+			jobVertexID,
+			jobVertexIDs,
+			null,
+			parallelism,
+			maxParallelism,
+			state,
+			successiveStates);
+	}
+
+	static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		JobVertexID jobVertexID,
+		List<OperatorID> jobVertexIDs,
+		@Nullable LogicalSlot slot,
+		int parallelism,
+		int maxParallelism,
+		ExecutionState state,
+		ExecutionState ... successiveStates) throws Exception {
 
 		ExecutionVertex vertex = mock(ExecutionVertex.class);
 
@@ -378,6 +433,15 @@ public class CheckpointCoordinatorTestingUtils {
 			1L,
 			Time.milliseconds(500L)
 		));
+		if (slot != null) {
+			// the slot is injected reflectively via Execution's internal ASSIGNED_SLOT_UPDATER; is there a better way to do this?
+			//noinspection unchecked
+			AtomicReferenceFieldUpdater<Execution, LogicalSlot> slotUpdater =
+				(AtomicReferenceFieldUpdater<Execution, LogicalSlot>)
+					Whitebox.getInternalState(exec, "ASSIGNED_SLOT_UPDATER");
+			slotUpdater.compareAndSet(exec, null, slot);
+		}
+
 		when(exec.getAttemptId()).thenReturn(attemptID);
 		when(exec.getState()).thenReturn(state, successiveStates);
 
@@ -455,6 +519,29 @@ public class CheckpointCoordinatorTestingUtils {
 		return mock;
 	}
 
+	static Execution mockExecution(CheckpointConsumer checkpointConsumer) {
+		ExecutionVertex executionVertex = mock(ExecutionVertex.class);
+		final JobID jobId = new JobID();
+		when(executionVertex.getJobId()).thenReturn(jobId);
+		Execution mock = mock(Execution.class);
+		ExecutionAttemptID executionAttemptID = new ExecutionAttemptID();
+		when(mock.getAttemptId()).thenReturn(executionAttemptID);
+		when(mock.getState()).thenReturn(ExecutionState.RUNNING);
+		when(mock.getVertex()).thenReturn(executionVertex);
+		doAnswer((InvocationOnMock invocation) -> {
+			final Object[] args = invocation.getArguments();
+			checkpointConsumer.accept(
+				executionAttemptID,
+				jobId,
+				(long) args[0],
+				(long) args[1],
+				(CheckpointOptions) args[2],
+				false);
+			return null;
+		}).when(mock).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
+		return mock;
+	}
+
 	static ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
 		ExecutionVertex mock = mock(ExecutionVertex.class);
 		when(mock.getJobvertexId()).thenReturn(vertexId);
@@ -479,115 +566,4 @@ public class CheckpointCoordinatorTestingUtils {
 		}
 		return vertex;
 	}
-
-	static class TestingScheduledServiceWithRecordingScheduledTasks implements ScheduledExecutor {
-
-		private final ScheduledExecutor scheduledExecutor;
-
-		private final Set<UUID> tasksScheduledOnce;
-
-		public TestingScheduledServiceWithRecordingScheduledTasks(ScheduledExecutor scheduledExecutor) {
-			this.scheduledExecutor = checkNotNull(scheduledExecutor);
-			tasksScheduledOnce = new HashSet<>();
-		}
-
-		public int getNumScheduledOnceTasks() {
-			synchronized (tasksScheduledOnce) {
-				return tasksScheduledOnce.size();
-			}
-		}
-
-		@Override
-		public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
-			final UUID id = UUID.randomUUID();
-			synchronized (tasksScheduledOnce) {
-				tasksScheduledOnce.add(id);
-			}
-			return new TestingScheduledFuture<>(id, scheduledExecutor.schedule(() -> {
-				synchronized (tasksScheduledOnce) {
-					tasksScheduledOnce.remove(id);
-				}
-				command.run();
-			}, delay, unit));
-		}
-
-		@Override
-		public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
-			final UUID id = UUID.randomUUID();
-			synchronized (tasksScheduledOnce) {
-				tasksScheduledOnce.add(id);
-			}
-			return new TestingScheduledFuture<>(id, scheduledExecutor.schedule(() -> {
-				synchronized (tasksScheduledOnce) {
-					tasksScheduledOnce.remove(id);
-				}
-				return callable.call();
-			}, delay, unit));
-		}
-
-		@Override
-		public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) {
-			return scheduledExecutor.scheduleAtFixedRate(command, initialDelay, period, unit);
-		}
-
-		@Override
-		public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) {
-			return scheduledExecutor.scheduleWithFixedDelay(command, initialDelay, delay, unit);
-		}
-
-		@Override
-		public void execute(Runnable command) {
-			scheduledExecutor.execute(command);
-		}
-
-		private class TestingScheduledFuture<V> implements ScheduledFuture<V> {
-
-			private final ScheduledFuture<V> scheduledFuture;
-
-			private final UUID id;
-
-			public TestingScheduledFuture(UUID id, ScheduledFuture<V> scheduledFuture) {
-				this.id = checkNotNull(id);
-				this.scheduledFuture = checkNotNull(scheduledFuture);
-			}
-
-			@Override
-			public long getDelay(TimeUnit unit) {
-				return scheduledFuture.getDelay(unit);
-			}
-
-			@Override
-			public int compareTo(Delayed o) {
-				return scheduledFuture.compareTo(o);
-			}
-
-			@Override
-			public boolean cancel(boolean mayInterruptIfRunning) {
-				synchronized (tasksScheduledOnce) {
-					tasksScheduledOnce.remove(id);
-				}
-				return scheduledFuture.cancel(mayInterruptIfRunning);
-			}
-
-			@Override
-			public boolean isCancelled() {
-				return scheduledFuture.isCancelled();
-			}
-
-			@Override
-			public boolean isDone() {
-				return scheduledFuture.isDone();
-			}
-
-			@Override
-			public V get() throws InterruptedException, ExecutionException {
-				return scheduledFuture.get();
-			}
-
-			@Override
-			public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
-				return scheduledFuture.get(timeout, unit);
-			}
-		}
-	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
index a157b6f..4a761d9 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
@@ -28,11 +29,9 @@ import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguratio
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
-import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -55,9 +54,7 @@ import static org.mockito.Mockito.doAnswer;
 public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
-	@Rule
-	public final TestingScheduledExecutor testingScheduledExecutor =
-		new TestingScheduledExecutor();
+	private ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor;
 
 	private CheckpointFailureManager failureManager;
 
@@ -66,6 +63,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 		failureManager = new CheckpointFailureManager(
 			0,
 			NoOpFailJobCall.INSTANCE);
+		manuallyTriggeredScheduledExecutor = new ManuallyTriggeredScheduledExecutor();
 	}
 
 	@Test
@@ -128,51 +126,39 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			coord.startCheckpointScheduler();
 
-			long timeout = System.currentTimeMillis() + 60000;
 			do {
-				Thread.sleep(20);
+				manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			}
-			while (timeout > System.currentTimeMillis() && numCalls.get() < 5);
-			assertTrue(numCalls.get() >= 5);
+			while (numCalls.get() < 5);
+			assertEquals(5, numCalls.get());
 
 			coord.stopCheckpointScheduler();
 
-			// for 400 ms, no further calls may come.
-			// there may be the case that one trigger was fired and about to
-			// acquire the lock, such that after cancelling it will still do
-			// the remainder of its work
-			int numCallsSoFar = numCalls.get();
-			Thread.sleep(400);
-			assertTrue(numCallsSoFar == numCalls.get() ||
-				numCallsSoFar + 1 == numCalls.get());
+			// no further calls may come.
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
+			assertEquals(5, numCalls.get());
 
 			// start another sequence of periodic scheduling
 			numCalls.set(0);
 			coord.startCheckpointScheduler();
 
-			timeout = System.currentTimeMillis() + 60000;
 			do {
-				Thread.sleep(20);
+				manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 			}
-			while (timeout > System.currentTimeMillis() && numCalls.get() < 5);
-			assertTrue(numCalls.get() >= 5);
+			while (numCalls.get() < 5);
+			assertEquals(5, numCalls.get());
 
 			coord.stopCheckpointScheduler();
 
-			// for 400 ms, no further calls may come
-			// there may be the case that one trigger was fired and about to
-			// acquire the lock, such that after cancelling it will still do
-			// the remainder of its work
-			numCallsSoFar = numCalls.get();
-			Thread.sleep(400);
-			assertTrue(numCallsSoFar == numCalls.get() ||
-				numCallsSoFar + 1 == numCalls.get());
+			// no further calls may come
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
+			assertEquals(5, numCalls.get());
 
 			coord.shutdown(JobStatus.FINISHED);
 		}
@@ -203,9 +189,10 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 		}).when(executionAttempt).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
 
 		final long delay = 50;
+		final long checkpointInterval = 12;
 
 		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			12,           // periodic interval is 12 ms
+			checkpointInterval,           // periodic interval is 12 ms
 			200_000,     // timeout is very long (200 s)
 			delay,       // 50 ms delay between checkpoints
 			1,
@@ -223,12 +210,13 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(2),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
 		try {
 			coord.startCheckpointScheduler();
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
 
 			// wait until the first checkpoint was triggered
 			Long firstCallId = triggerCalls.take();
@@ -240,6 +228,12 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 			final long ackTime = System.nanoTime();
 			coord.receiveAcknowledgeMessage(ackMsg, TASK_MANAGER_LOCATION_INFO);
 
+			manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
+			while (triggerCalls.isEmpty()) {
+				// sleep for one checkpoint interval between triggers to simulate periodic scheduling
+				Thread.sleep(checkpointInterval);
+				manuallyTriggeredScheduledExecutor.triggerPeriodicScheduledTasks();
+			}
 			// wait until the next checkpoint is triggered
 			Long nextCallId = triggerCalls.take();
 			final long nextCheckpointTime = System.nanoTime();
@@ -284,7 +278,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -296,6 +290,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 				null,
 				true,
 				false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
 			fail("The triggerCheckpoint call expected an exception");
 		} catch (CheckpointException e) {
 			assertEquals(CheckpointFailureReason.PERIODIC_SCHEDULER_SHUTDOWN, e.getCheckpointFailureReason());
@@ -309,6 +304,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 				null,
 				false,
 				false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
 		} catch (CheckpointException e) {
 			fail("Unexpected exception : " + e.getCheckpointFailureReason().message());
 		}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 080e1c7..ca5a42e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -34,13 +35,11 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
-import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.SerializableObject;
 
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Before;
-import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.Mockito;
 import org.mockito.hamcrest.MockitoHamcrest;
@@ -65,9 +64,7 @@ public class CheckpointStateRestoreTest {
 
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
-	@Rule
-	public final TestingScheduledExecutor testingScheduledExecutor =
-		new TestingScheduledExecutor();
+	private ManuallyTriggeredScheduledExecutor manuallyTriggeredScheduledExecutor;
 
 	private CheckpointFailureManager failureManager;
 
@@ -76,6 +73,7 @@ public class CheckpointStateRestoreTest {
 		failureManager = new CheckpointFailureManager(
 			0,
 			NoOpFailJobCall.INSTANCE);
+		manuallyTriggeredScheduledExecutor = new ManuallyTriggeredScheduledExecutor();
 	}
 
 	/**
@@ -133,13 +131,14 @@ public class CheckpointStateRestoreTest {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
 			// create ourselves a checkpoint with state
 			final long timestamp = 34623786L;
 			coord.triggerCheckpoint(timestamp, false);
+			manuallyTriggeredScheduledExecutor.triggerAll();
 
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
@@ -218,7 +217,7 @@ public class CheckpointStateRestoreTest {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
-				testingScheduledExecutor.getScheduledExecutor(),
+				manuallyTriggeredScheduledExecutor,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -284,7 +283,7 @@ public class CheckpointStateRestoreTest {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
-			testingScheduledExecutor.getScheduledExecutor(),
+			manuallyTriggeredScheduledExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
index df3e97e..ab1424d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
@@ -92,8 +92,9 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
 
 		checkpointCoordinator.startCheckpointScheduler();
 		assertTrue(checkpointCoordinator.isCurrentPeriodicTriggerAvailable());
-		manualThreadExecutor.triggerAll();
-		manualThreadExecutor.triggerScheduledTasks();
+		// Only trigger the periodically scheduled tasks.
+		// We can't trigger all scheduled tasks, because a cancellation is scheduled as well.
+		manualThreadExecutor.triggerPeriodicScheduledTasks();
 		assertEquals(1, checkpointCoordinator.getNumberOfPendingCheckpoints());
 
 		for (int i = 1; i < maxConcurrentCheckpoints; i++) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ManuallyTriggeredScheduledExecutor.java b/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ManuallyTriggeredScheduledExecutor.java
index 870cda8..de2ec72 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ManuallyTriggeredScheduledExecutor.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/concurrent/ManuallyTriggeredScheduledExecutor.java
@@ -35,6 +35,7 @@ import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledFuture;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.stream.Collectors;
 
 /**
  * Simple {@link ScheduledExecutor} implementation for testing purposes.
@@ -42,8 +43,14 @@ import java.util.concurrent.TimeoutException;
 public class ManuallyTriggeredScheduledExecutor implements ScheduledExecutor {
 
 	private final Executor executorDelegate;
+
 	private final ArrayDeque<Runnable> queuedRunnables = new ArrayDeque<>();
-	private final ConcurrentLinkedQueue<ScheduledTask<?>> scheduledTasks = new ConcurrentLinkedQueue<>();
+
+	private final ConcurrentLinkedQueue<ScheduledTask<?>> nonPeriodicScheduledTasks =
+		new ConcurrentLinkedQueue<>();
+
+	private final ConcurrentLinkedQueue<ScheduledTask<?>> periodicScheduledTasks =
+		new ConcurrentLinkedQueue<>();
 
 	public ManuallyTriggeredScheduledExecutor() {
 		this.executorDelegate = Runnable::run;
@@ -74,11 +81,7 @@ public class ManuallyTriggeredScheduledExecutor implements ScheduledExecutor {
 			next = queuedRunnables.removeFirst();
 		}
 
-		if (next != null) {
-			CompletableFuture.runAsync(next, executorDelegate).join();
-		} else {
-			throw new IllegalStateException("No runnable available");
-		}
+		CompletableFuture.runAsync(next, executorDelegate).join();
 	}
 
 	/**
@@ -92,58 +95,107 @@ public class ManuallyTriggeredScheduledExecutor implements ScheduledExecutor {
 
 	@Override
 	public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
-		return insertRunnable(command, false);
+		return insertNonPeriodicTask(command, delay, unit);
 	}
 
 	@Override
 	public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
-		final ScheduledTask<V> scheduledTask = new ScheduledTask<>(callable, false);
-
-		scheduledTasks.offer(scheduledTask);
-
-		return scheduledTask;
+		return insertNonPeriodicTask(callable, delay, unit);
 	}
 
 	@Override
 	public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) {
-		return insertRunnable(command, true);
+		return insertPeriodicRunnable(command, initialDelay, period, unit);
 	}
 
 	@Override
 	public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) {
-		return insertRunnable(command, true);
+		return insertPeriodicRunnable(command, initialDelay, delay, unit);
 	}
 
 	public Collection<ScheduledFuture<?>> getScheduledTasks() {
-		return new ArrayList<>(scheduledTasks);
+		final ArrayList<ScheduledFuture<?>> scheduledTasks =
+			new ArrayList<>(nonPeriodicScheduledTasks.size() + periodicScheduledTasks.size());
+		scheduledTasks.addAll(getNonPeriodicScheduledTask());
+		scheduledTasks.addAll(getPeriodicScheduledTask());
+		return scheduledTasks;
+	}
+
+	public Collection<ScheduledFuture<?>> getPeriodicScheduledTask() {
+		return periodicScheduledTasks
+			.stream()
+			.filter(scheduledTask -> !scheduledTask.isCancelled())
+			.collect(Collectors.toList());
+	}
+
+	public Collection<ScheduledFuture<?>> getNonPeriodicScheduledTask() {
+		return nonPeriodicScheduledTasks
+			.stream()
+			.filter(scheduledTask -> !scheduledTask.isCancelled())
+			.collect(Collectors.toList());
 	}
 
 	/**
 	 * Triggers all registered tasks.
 	 */
 	public void triggerScheduledTasks() {
-		final Iterator<ScheduledTask<?>> iterator = scheduledTasks.iterator();
+		triggerPeriodicScheduledTasks();
+		triggerNonPeriodicScheduledTasks();
+	}
+
+	public void triggerNonPeriodicScheduledTasks() {
+		final Iterator<ScheduledTask<?>> iterator = nonPeriodicScheduledTasks.iterator();
 
 		while (iterator.hasNext()) {
 			final ScheduledTask<?> scheduledTask = iterator.next();
 
-			scheduledTask.execute();
+			if (!scheduledTask.isCancelled()) {
+				scheduledTask.execute();
+			}
+			iterator.remove();
+		}
+	}
 
-			if (!scheduledTask.isPeriodic) {
-				iterator.remove();
+	public void triggerPeriodicScheduledTasks() {
+		for (ScheduledTask<?> scheduledTask : periodicScheduledTasks) {
+			if (!scheduledTask.isCancelled()) {
+				scheduledTask.execute();
 			}
 		}
 	}
 
-	private ScheduledFuture<?> insertRunnable(Runnable command, boolean isPeriodic) {
+	private ScheduledFuture<?> insertPeriodicRunnable(
+		Runnable command,
+		long delay,
+		long period,
+		TimeUnit unit) {
+		// ScheduledTask keeps delay and period in milliseconds (see getDelay), so convert from the caller's unit.
 		final ScheduledTask<?> scheduledTask = new ScheduledTask<>(
 			() -> {
 				command.run();
 				return null;
 			},
-			isPeriodic);
+			TimeUnit.MILLISECONDS.convert(delay, unit),
+			TimeUnit.MILLISECONDS.convert(period, unit));
 
-		scheduledTasks.offer(scheduledTask);
+		periodicScheduledTasks.offer(scheduledTask);
+
+		return scheduledTask;
+	}
+
+	private ScheduledFuture<?> insertNonPeriodicTask(Runnable command, long delay, TimeUnit unit) {
+		return insertNonPeriodicTask(() -> {
+			command.run();
+			return null;
+		}, delay, unit);
+	}
+
+	private <V> ScheduledFuture<V> insertNonPeriodicTask(
+		Callable<V> callable, long delay, TimeUnit unit) {
+		// ScheduledTask keeps its delay in milliseconds (see getDelay), so convert from the caller's unit.
+		final ScheduledTask<V> scheduledTask =
+			new ScheduledTask<>(callable, TimeUnit.MILLISECONDS.convert(delay, unit));
+		nonPeriodicScheduledTasks.offer(scheduledTask);
 
 		return scheduledTask;
 	}
@@ -152,20 +204,30 @@ public class ManuallyTriggeredScheduledExecutor implements ScheduledExecutor {
 
 		private final Callable<T> callable;
 
-		private final boolean isPeriodic;
+		private final long delay;
+
+		private final long period;
 
 		private final CompletableFuture<T> result;
 
-		private ScheduledTask(Callable<T> callable, boolean isPeriodic) {
-			this.callable = Preconditions.checkNotNull(callable);
-			this.isPeriodic = isPeriodic;
+		private ScheduledTask(Callable<T> callable, long delay) {
+			this(callable, delay, 0);
+		}
 
+		private ScheduledTask(Callable<T> callable, long delay, long period) {
+			this.callable = Preconditions.checkNotNull(callable);
 			this.result = new CompletableFuture<>();
+			this.delay = delay;
+			this.period = period;
+		}
+
+		private boolean isPeriodic() {
+			return period > 0;
 		}
 
 		public void execute() {
 			if (!result.isDone()) {
-				if (!isPeriodic) {
+				if (!isPeriodic()) {
 					try {
 						result.complete(callable.call());
 					} catch (Exception e) {
@@ -183,12 +245,12 @@ public class ManuallyTriggeredScheduledExecutor implements ScheduledExecutor {
 
 		@Override
 		public long getDelay(TimeUnit unit) {
-			return 0;
+			return unit.convert(delay, TimeUnit.MILLISECONDS);
 		}
 
 		@Override
 		public int compareTo(Delayed o) {
-			return 0;
+			return Long.compare(getDelay(TimeUnit.MILLISECONDS), o.getDelay(TimeUnit.MILLISECONDS));
 		}
 
 		@Override
@@ -212,7 +274,8 @@ public class ManuallyTriggeredScheduledExecutor implements ScheduledExecutor {
 		}
 
 		@Override
-		public T get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
+		public T get(long timeout, @Nonnull TimeUnit unit)
+			throws InterruptedException, ExecutionException, TimeoutException {
 			return result.get(timeout, unit);
 		}
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java
index e09d8be..3408e11 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java
@@ -20,7 +20,6 @@ package org.apache.flink.runtime.executiongraph.utils;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
-import org.apache.flink.api.java.tuple.Tuple6;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.concurrent.FutureUtils;
@@ -56,7 +55,13 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 
 	private BiConsumer<JobID, Collection<ResultPartitionID>> releasePartitionsConsumer = (ignore1, ignore2) -> { };
 
-	private Consumer<Tuple6<ExecutionAttemptID, JobID, Long, Long, CheckpointOptions, Boolean>> checkpointConsumer = ignore -> { };
+	private CheckpointConsumer checkpointConsumer = (
+		executionAttemptID,
+		jobId,
+		checkpointId,
+		timestamp,
+		checkpointOptions,
+		advanceToEndOfEventTime) -> { };
 
 	public void setSubmitConsumer(Consumer<TaskDeploymentDescriptor> submitConsumer) {
 		this.submitConsumer = submitConsumer;
@@ -74,7 +79,7 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 		this.releasePartitionsConsumer = releasePartitionsConsumer;
 	}
 
-	public void setCheckpointConsumer(Consumer<Tuple6<ExecutionAttemptID, JobID, Long, Long, CheckpointOptions, Boolean>> checkpointConsumer) {
+	public void setCheckpointConsumer(CheckpointConsumer checkpointConsumer) {
 		this.checkpointConsumer = checkpointConsumer;
 	}
 
@@ -131,7 +136,14 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 			long timestamp,
 			CheckpointOptions checkpointOptions,
 			boolean advanceToEndOfEventTime) {
-		checkpointConsumer.accept(Tuple6.of(executionAttemptID, jobId, checkpointId, timestamp, checkpointOptions, advanceToEndOfEventTime));
+
+		checkpointConsumer.accept(
+			executionAttemptID,
+			jobId,
+			checkpointId,
+			timestamp,
+			checkpointOptions,
+			advanceToEndOfEventTime);
 	}
 
 	@Override
@@ -144,4 +156,18 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 			return CompletableFuture.completedFuture(Acknowledge.get());
 		}
 	}
+
+	/**
+	 * Consumer that accepts checkpoint trigger information.
+	 */
+	public interface CheckpointConsumer {
+
+		void accept(
+			ExecutionAttemptID executionAttemptID,
+			JobID jobId,
+			long checkpointId,
+			long timestamp,
+			CheckpointOptions checkpointOptions,
+			boolean advanceToEndOfEventTime);
+	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java
index 930af9f..d9cfb11 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java
@@ -38,15 +38,15 @@ public class TestingScheduledExecutor extends ExternalResource {
 	private ScheduledExecutorService innerExecutorService;
 
 	public TestingScheduledExecutor() {
-			this(500L);
-		}
+		this(500L);
+	}
 
 	public TestingScheduledExecutor(long shutdownTimeoutMillis) {
 		this.shutdownTimeoutMillis = shutdownTimeoutMillis;
 	}
 
 	@Override
-	protected void before() {
+	public void before() {
 		this.innerExecutorService = Executors.newSingleThreadScheduledExecutor();
 		this.scheduledExecutor = new ScheduledExecutorServiceAdapter(innerExecutorService);
 	}
@@ -56,7 +56,7 @@ public class TestingScheduledExecutor extends ExternalResource {
 		ExecutorUtils.gracefulShutdown(shutdownTimeoutMillis, TimeUnit.MILLISECONDS, innerExecutorService);
 	}
 
-	public ScheduledExecutor getScheduledExecutor() {
+	protected ScheduledExecutor getScheduledExecutor() {
 		return scheduledExecutor;
 	}
 }


[flink] 03/08: [hotfix] Correct code style of CheckpointCoordinator

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 3c84c058c67a02671cf325b27576b6faaba1d836
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Thu Sep 19 21:23:35 2019 +0800

    [hotfix] Correct code style of CheckpointCoordinator
---
 .../org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java    | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index f30b135..e7c7e17 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -1249,7 +1249,7 @@ public class CheckpointCoordinator {
 	}
 
 	/**
-	 * If too many checkpoints are currently in progress, we need to mark that a request is queued
+	 * If too many checkpoints are currently in progress, we need to mark that a request is queued.
 	 *
 	 * @throws CheckpointException If too many checkpoints are currently in progress.
 	 */
@@ -1265,7 +1265,7 @@ public class CheckpointCoordinator {
 	}
 
 	/**
-	 * Make sure the minimum interval between checkpoints has passed
+	 * Make sure the minimum interval between checkpoints has passed.
 	 *
 	 * @throws CheckpointException If the minimum interval between checkpoints has not passed.
 	 */


[flink] 05/08: [FLINK-13904][tests] Support checkpoint consumer of SimpleAckingTaskManagerGateway

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit fc196737b818261d039d6ecb2c1555c340f0e2c0
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Sun Sep 29 14:50:09 2019 +0800

    [FLINK-13904][tests] Support checkpoint consumer of SimpleAckingTaskManagerGateway
---
 .../executiongraph/utils/SimpleAckingTaskManagerGateway.java  | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java
index 0d07f3d..e09d8be 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/utils/SimpleAckingTaskManagerGateway.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.executiongraph.utils;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
+import org.apache.flink.api.java.tuple.Tuple6;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.concurrent.FutureUtils;
@@ -55,6 +56,8 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 
 	private BiConsumer<JobID, Collection<ResultPartitionID>> releasePartitionsConsumer = (ignore1, ignore2) -> { };
 
+	private Consumer<Tuple6<ExecutionAttemptID, JobID, Long, Long, CheckpointOptions, Boolean>> checkpointConsumer = ignore -> { };
+
 	public void setSubmitConsumer(Consumer<TaskDeploymentDescriptor> submitConsumer) {
 		this.submitConsumer = submitConsumer;
 	}
@@ -71,6 +74,10 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 		this.releasePartitionsConsumer = releasePartitionsConsumer;
 	}
 
+	public void setCheckpointConsumer(Consumer<Tuple6<ExecutionAttemptID, JobID, Long, Long, CheckpointOptions, Boolean>> checkpointConsumer) {
+		this.checkpointConsumer = checkpointConsumer;
+	}
+
 	@Override
 	public String getAddress() {
 		return address;
@@ -123,7 +130,9 @@ public class SimpleAckingTaskManagerGateway implements TaskManagerGateway {
 			long checkpointId,
 			long timestamp,
 			CheckpointOptions checkpointOptions,
-			boolean advanceToEndOfEventTime) {}
+			boolean advanceToEndOfEventTime) {
+		checkpointConsumer.accept(Tuple6.of(executionAttemptID, jobId, checkpointId, timestamp, checkpointOptions, advanceToEndOfEventTime));
+	}
 
 	@Override
 	public CompletableFuture<Acknowledge> freeSlot(AllocationID allocationId, Throwable cause, Time timeout) {


[flink] 02/08: [hotfix] Split too large file CheckpointCoordinatorTest.java into several small files

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 22c32483f800da95fb2281ac58f287e866d0e96b
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Wed Sep 18 19:25:12 2019 +0800

    [hotfix] Split too large file CheckpointCoordinatorTest.java into several small files
---
 .../CheckpointCoordinatorRestoringTest.java        |  996 ++++++++++++++++
 .../checkpoint/CheckpointCoordinatorTest.java      | 1256 +-------------------
 .../CheckpointCoordinatorTriggeringTest.java       |  308 +++++
 3 files changed, 1353 insertions(+), 1207 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
new file mode 100644
index 0000000..1259144
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
@@ -0,0 +1,996 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
+import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
+import org.apache.flink.util.SerializableObject;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
+
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.mockito.Mockito;
+import org.mockito.hamcrest.MockitoHamcrest;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
+
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.compareKeyedState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.comparePartitionableState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateChainedPartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateKeyGroupState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecution;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for restoring checkpoint.
+ */
+public class CheckpointCoordinatorRestoringTest extends TestLogger {
+	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
+
+	private CheckpointFailureManager failureManager;
+
+	@Rule
+	public TemporaryFolder tmpFolder = new TemporaryFolder();
+
+	@Before
+	public void setUp() throws Exception {
+		failureManager = new CheckpointFailureManager(
+			0,
+			NoOpFailJobCall.INSTANCE);
+	}
+
+	/**
+	 * Tests that the checkpointed partitioned and non-partitioned state is assigned properly to
+	 * the {@link Execution} upon recovery.
+	 *
+	 * @throws Exception
+	 */
+	@Test
+	public void testRestoreLatestCheckpointedState() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices =
+			allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+
+		CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore();
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+			true,
+			false,
+			0);
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			chkConfig,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			new StandaloneCheckpointIDCounter(),
+			store,
+			new MemoryStateBackend(),
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY,
+			failureManager);
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp, false);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				subtaskState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+		}
+
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				subtaskState);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		// shutdown the store
+		store.shutdown(JobStatus.SUSPENDED);
+
+		// restore the store
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		tasks.put(jobVertexID1, jobVertex1);
+		tasks.put(jobVertexID2, jobVertex2);
+
+		coord.restoreLatestCheckpointedState(tasks, true, false);
+
+		// validate that all shared states are registered again after the recovery.
+		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
+			for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) {
+				for (OperatorSubtaskState subtaskState : taskState.getStates()) {
+					verify(subtaskState, times(2)).registerSharedStates(any(SharedStateRegistry.class));
+				}
+			}
+		}
+
+		// verify the restored state
+		verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
+		verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
+	}
+
+	@Test
+	public void testRestoreLatestCheckpointedStateScaleIn() throws Exception {
+		testRestoreLatestCheckpointedStateWithChangingParallelism(false);
+	}
+
+	@Test
+	public void testRestoreLatestCheckpointedStateScaleOut() throws Exception {
+		testRestoreLatestCheckpointedStateWithChangingParallelism(true);
+	}
+
+	@Test
+	public void testRestoreLatestCheckpointWhenPreferCheckpoint() throws Exception {
+		testRestoreLatestCheckpointIsPreferSavepoint(true);
+	}
+
+	@Test
+	public void testRestoreLatestCheckpointWhenPreferSavepoint() throws Exception {
+		testRestoreLatestCheckpointIsPreferSavepoint(false);
+	}
+
+	private void testRestoreLatestCheckpointIsPreferSavepoint(boolean isPreferCheckpoint) {
+		try {
+			final JobID jid = new JobID();
+			long timestamp = System.currentTimeMillis();
+			StandaloneCheckpointIDCounter checkpointIDCounter = new StandaloneCheckpointIDCounter();
+
+			final JobVertexID statefulId = new JobVertexID();
+			final JobVertexID statelessId = new JobVertexID();
+
+			Execution statefulExec1 = mockExecution();
+			Execution statelessExec1 = mockExecution();
+
+			ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 1);
+			ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 1);
+
+			ExecutionJobVertex stateful = mockExecutionJobVertex(statefulId,
+				new ExecutionVertex[] { stateful1 });
+			ExecutionJobVertex stateless = mockExecutionJobVertex(statelessId,
+				new ExecutionVertex[] { stateless1 });
+
+			Map<JobVertexID, ExecutionJobVertex> map = new HashMap<JobVertexID, ExecutionJobVertex>();
+			map.put(statefulId, stateful);
+			map.put(statelessId, stateless);
+
+			CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore(2);
+
+			CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+				600000,
+				600000,
+				0,
+				Integer.MAX_VALUE,
+				CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+				true,
+				isPreferCheckpoint,
+				0);
+			CheckpointCoordinator coord = new CheckpointCoordinator(
+				jid,
+				chkConfig,
+				new ExecutionVertex[] { stateful1, stateless1 },
+				new ExecutionVertex[] { stateful1, stateless1 },
+				new ExecutionVertex[] { stateful1, stateless1 },
+				checkpointIDCounter,
+				store,
+				new MemoryStateBackend(),
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY,
+				failureManager);
+
+			// trigger a checkpoint and wait for it to become a completed checkpoint
+			assertTrue(coord.triggerCheckpoint(timestamp, false));
+
+			long checkpointId = checkpointIDCounter.getLast();
+
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
+			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
+			KeyedStateHandle serializedKeyGroupStates = generateKeyGroupState(keyGroupRange, testStates);
+
+			TaskStateSnapshot subtaskStatesForCheckpoint = new TaskStateSnapshot();
+
+			subtaskStatesForCheckpoint.putSubtaskStateByOperatorID(
+				OperatorID.fromJobVertexID(statefulId),
+				new OperatorSubtaskState(
+					StateObjectCollection.empty(),
+					StateObjectCollection.empty(),
+					StateObjectCollection.singleton(serializedKeyGroupStates),
+					StateObjectCollection.empty()));
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStatesForCheckpoint), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId), TASK_MANAGER_LOCATION_INFO);
+
+			CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0);
+			assertEquals(jid, success.getJobId());
+
+			// trigger a savepoint and wait for it to be finished
+			String savepointDir = tmpFolder.newFolder().getAbsolutePath();
+			timestamp = System.currentTimeMillis();
+			CompletableFuture<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir);
+
+			KeyGroupRange keyGroupRangeForSavepoint = KeyGroupRange.of(1, 1);
+			List<SerializableObject> testStatesForSavepoint = Collections.singletonList(new SerializableObject());
+			KeyedStateHandle serializedKeyGroupStatesForSavepoint = generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
+
+			TaskStateSnapshot subtaskStatesForSavepoint = new TaskStateSnapshot();
+
+			subtaskStatesForSavepoint.putSubtaskStateByOperatorID(
+				OperatorID.fromJobVertexID(statefulId),
+				new OperatorSubtaskState(
+					StateObjectCollection.empty(),
+					StateObjectCollection.empty(),
+					StateObjectCollection.singleton(serializedKeyGroupStatesForSavepoint),
+					StateObjectCollection.empty()));
+
+			checkpointId = checkpointIDCounter.getLast();
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStatesForSavepoint), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId), TASK_MANAGER_LOCATION_INFO);
+
+			assertTrue(savepointFuture.isDone());
+
+			//restore; the latest savepoint is expected to be skipped only if isPreferCheckpoint is set
+			coord.restoreLatestCheckpointedState(map, true, false);
+
+			//compare and see if it used the checkpoint's (isPreferCheckpoint) or the savepoint's subtaskStates
+			BaseMatcher<JobManagerTaskRestore> matcher = new BaseMatcher<JobManagerTaskRestore>() {
+				@Override
+				public boolean matches(Object o) {
+					if (o instanceof JobManagerTaskRestore) {
+						JobManagerTaskRestore taskRestore = (JobManagerTaskRestore) o;
+						if (isPreferCheckpoint) {
+							return Objects.equals(taskRestore.getTaskStateSnapshot(), subtaskStatesForCheckpoint);
+						} else {
+							return Objects.equals(taskRestore.getTaskStateSnapshot(), subtaskStatesForSavepoint);
+						}
+					}
+					return false;
+				}
+
+				@Override
+				public void describeTo(Description description) {
+					if (isPreferCheckpoint) {
+						description.appendValue(subtaskStatesForCheckpoint);
+					} else {
+						description.appendValue(subtaskStatesForSavepoint);
+					}
+				}
+			};
+
+			verify(statefulExec1, times(1)).setInitialState(MockitoHamcrest.argThat(matcher));
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<JobManagerTaskRestore>any());
+
+			coord.shutdown(JobStatus.FINISHED);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	/**
+	 * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned
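+	 * state: vertex 2 is rescaled from 2 to 13 if {@code scaleOut} is set, otherwise from 13 to 2.
+	 *
+	 * @throws Exception if any step of the test throws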
+	 */
+	private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut) throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = scaleOut ? 2 : 13;
+
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		int newParallelism2 = scaleOut ? 13 : 2;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices =
+			allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+
+		// set up the coordinator
+		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+			true,
+			false,
+			0);
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			chkConfig,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1),
+			new MemoryStateBackend(),
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY,
+			failureManager);
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp, false);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 =
+			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 =
+			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		//vertex 1
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(opStateBackend, null, keyedStateBackend, keyedStateRaw);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				taskOperatorSubtaskStates);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+		}
+
+		//vertex 2
+		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesBackend = new ArrayList<>(jobVertex2.getParallelism());
+		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesRaw = new ArrayList<>(jobVertex2.getParallelism());
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
+			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true);
+			expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend)));
+			expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw)));
+
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
+
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				taskOperatorSubtaskStates);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		List<KeyGroupRange> newKeyGroupPartitions2 =
+			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+
+		// rescale vertex 2
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			newParallelism2,
+			maxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+		coord.restoreLatestCheckpointedState(tasks, true, false);
+
+		// verify the restored state
+		verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
+		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
+		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
+		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
+
+			List<OperatorID> operatorIDs = newJobVertex2.getOperatorIDs();
+
+			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
+			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
+
+			JobManagerTaskRestore taskRestore = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
+			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
+			TaskStateSnapshot taskStateHandles = taskRestore.getTaskStateSnapshot();
+
+			final int headOpIndex = operatorIDs.size() - 1;
+			List<Collection<OperatorStateHandle>> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size());
+			List<Collection<OperatorStateHandle>> allParallelRawOpStates = new ArrayList<>(operatorIDs.size());
+
+			for (int idx = 0; idx < operatorIDs.size(); ++idx) {
+				OperatorID operatorID = operatorIDs.get(idx);
+				OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID);
+				Collection<OperatorStateHandle> opStateBackend = opState.getManagedOperatorState();
+				Collection<OperatorStateHandle> opStateRaw = opState.getRawOperatorState();
+				allParallelManagedOpStates.add(opStateBackend);
+				allParallelRawOpStates.add(opStateRaw);
+				if (idx == headOpIndex) {
+					Collection<KeyedStateHandle> keyedStateBackend = opState.getManagedKeyedState();
+					Collection<KeyedStateHandle> keyGroupStateRaw = opState.getRawKeyedState();
+					compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
+					compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
+				}
+			}
+			actualOpStatesBackend.add(allParallelManagedOpStates);
+			actualOpStatesRaw.add(allParallelRawOpStates);
+		}
+
+		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
+		comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw);
+	}
+
+	/**
+	 * Tests that the checkpoint restoration fails if the max parallelism of the job vertices has
+	 * changed.
+	 *
+	 * @throws Exception expected to be the {@link IllegalStateException} thrown by the restore call
+	 */
+	@Test(expected = IllegalStateException.class)
+	public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
+		final JobID jid = new JobID();
+		final long timestamp = System.currentTimeMillis();
+
+		final JobVertexID jobVertexID1 = new JobVertexID();
+		final JobVertexID jobVertexID2 = new JobVertexID();
+		int parallelism1 = 3;
+		int parallelism2 = 2;
+		int maxParallelism1 = 42;
+		int maxParallelism2 = 13;
+
+		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			maxParallelism1);
+		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			maxParallelism2);
+
+		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
+
+		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
+		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
+
+		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
+
+		// set up the coordinator
+		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+			true,
+			false,
+			0);
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			chkConfig,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			arrayExecutionVertices,
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1),
+			new MemoryStateBackend(),
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY,
+			failureManager);
+
+		// trigger the checkpoint
+		coord.triggerCheckpoint(timestamp, false);
+
+		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
+		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
+
+		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
+		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				taskOperatorSubtaskStates);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+		}
+
+		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
+			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
+			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
+				jid,
+				jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
+				checkpointId,
+				new CheckpointMetrics(),
+				taskOperatorSubtaskStates);
+
+			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+		}
+
+		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
+
+		assertEquals(1, completedCheckpoints.size());
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		int newMaxParallelism1 = 20;
+		int newMaxParallelism2 = 42;
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			jobVertexID1,
+			parallelism1,
+			newMaxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			jobVertexID2,
+			parallelism2,
+			newMaxParallelism2);
+
+		tasks.put(jobVertexID1, newJobVertex1);
+		tasks.put(jobVertexID2, newJobVertex2);
+
+		coord.restoreLatestCheckpointedState(tasks, true, false);
+
+		fail("The restoration should have failed because the max parallelism changed.");
+	}
+
+	@Test
+	public void testStateRecoveryWhenTopologyChangeOut() throws Exception {
+		testStateRecoveryWithTopologyChange(0);
+	}
+
+	@Test
+	public void testStateRecoveryWhenTopologyChangeIn() throws Exception {
+		testStateRecoveryWithTopologyChange(1);
+	}
+
+	@Test
+	public void testStateRecoveryWhenTopologyChange() throws Exception {
+		testStateRecoveryWithTopologyChange(2);
+	}
+
+	private static Tuple2<JobVertexID, OperatorID> generateIDPair() {
+		JobVertexID jobVertexID = new JobVertexID();
+		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
+		return new Tuple2<>(jobVertexID, operatorID);
+	}
+
+	/**
+	 * <p>
+	 * old topology.
+	 * [operator1,operator2] * parallelism1 -> [operator3,operator4] * parallelism2
+	 * </p>
+	 *
+	 * <p>
+	 * new topology
+	 *
+	 * [operator5,operator1,operator2] * newParallelism1 -> [operator3, operator6] * newParallelism2
+	 * </p>
+	 * scaleType (applies to the parallelism of the second vertex only):
+	 * 0  increase parallelism (10 to 20)
+	 * 1  decrease parallelism (10 to 8)
+	 * 2  same parallelism
+	 */
+	public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception {
+
+		/*
+		 * Old topology
+		 * CHAIN(op1 -> op2) * parallelism1 -> CHAIN(op3 -> op4) * parallelism2
+		 */
+		Tuple2<JobVertexID, OperatorID> id1 = generateIDPair();
+		Tuple2<JobVertexID, OperatorID> id2 = generateIDPair();
+		int parallelism1 = 10;
+		int maxParallelism1 = 64;
+
+		Tuple2<JobVertexID, OperatorID> id3 = generateIDPair();
+		Tuple2<JobVertexID, OperatorID> id4 = generateIDPair();
+		int parallelism2 = 10;
+		int maxParallelism2 = 64;
+
+		List<KeyGroupRange> keyGroupPartitions2 =
+			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
+
+		Map<OperatorID, OperatorState> operatorStates = new HashMap<>();
+
+		//prepare vertex1 state
+		for (Tuple2<JobVertexID, OperatorID> id : Arrays.asList(id1, id2)) {
+			OperatorState taskState = new OperatorState(id.f1, parallelism1, maxParallelism1);
+			operatorStates.put(id.f1, taskState);
+			for (int index = 0; index < taskState.getParallelism(); index++) {
+				OperatorStateHandle subManagedOperatorState =
+					generatePartitionableStateHandle(id.f0, index, 2, 8, false);
+				OperatorStateHandle subRawOperatorState =
+					generatePartitionableStateHandle(id.f0, index, 2, 8, true);
+				OperatorSubtaskState subtaskState = new OperatorSubtaskState(
+					subManagedOperatorState,
+					subRawOperatorState,
+					null,
+					null);
+				taskState.putState(index, subtaskState);
+			}
+		}
+
+		List<List<ChainedStateHandle<OperatorStateHandle>>> expectedManagedOperatorStates = new ArrayList<>();
+		List<List<ChainedStateHandle<OperatorStateHandle>>> expectedRawOperatorStates = new ArrayList<>();
+		//prepare vertex2 state
+		for (Tuple2<JobVertexID, OperatorID> id : Arrays.asList(id3, id4)) {
+			OperatorState operatorState = new OperatorState(id.f1, parallelism2, maxParallelism2);
+			operatorStates.put(id.f1, operatorState);
+			List<ChainedStateHandle<OperatorStateHandle>> expectedManagedOperatorState = new ArrayList<>();
+			List<ChainedStateHandle<OperatorStateHandle>> expectedRawOperatorState = new ArrayList<>();
+			expectedManagedOperatorStates.add(expectedManagedOperatorState);
+			expectedRawOperatorStates.add(expectedRawOperatorState);
+
+			for (int index = 0; index < operatorState.getParallelism(); index++) {
+				OperatorStateHandle subManagedOperatorState =
+					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
+						.get(0);
+				OperatorStateHandle subRawOperatorState =
+					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
+						.get(0);
+				KeyGroupsStateHandle subManagedKeyedState = id.f0.equals(id3.f0)
+					? generateKeyGroupState(id.f0, keyGroupPartitions2.get(index), false)
+					: null;
+				KeyGroupsStateHandle subRawKeyedState = id.f0.equals(id3.f0)
+					? generateKeyGroupState(id.f0, keyGroupPartitions2.get(index), true)
+					: null;
+
+				expectedManagedOperatorState.add(ChainedStateHandle.wrapSingleHandle(subManagedOperatorState));
+				expectedRawOperatorState.add(ChainedStateHandle.wrapSingleHandle(subRawOperatorState));
+
+				OperatorSubtaskState subtaskState = new OperatorSubtaskState(
+					subManagedOperatorState,
+					subRawOperatorState,
+					subManagedKeyedState,
+					subRawKeyedState);
+				operatorState.putState(index, subtaskState);
+			}
+		}
+
+		/*
+		 * New topology
+		 * CHAIN(op5 -> op1 -> op2) * newParallelism1 -> CHAIN(op3 -> op6) * newParallelism2
+		 */
+		Tuple2<JobVertexID, OperatorID> id5 = generateIDPair();
+		int newParallelism1 = 10;
+
+		Tuple2<JobVertexID, OperatorID> id6 = generateIDPair();
+		int newParallelism2 = parallelism2;
+
+		if (scaleType == 0) {
+			newParallelism2 = 20;
+		} else if (scaleType == 1) {
+			newParallelism2 = 8;
+		}
+
+		List<KeyGroupRange> newKeyGroupPartitions2 =
+			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
+
+		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
+			id5.f0,
+			Arrays.asList(id2.f1, id1.f1, id5.f1),
+			newParallelism1,
+			maxParallelism1);
+
+		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
+			id3.f0,
+			Arrays.asList(id6.f1, id3.f1),
+			newParallelism2,
+			maxParallelism2);
+
+		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
+
+		tasks.put(id5.f0, newJobVertex1);
+		tasks.put(id3.f0, newJobVertex2);
+
+		JobID jobID = new JobID();
+		StandaloneCompletedCheckpointStore standaloneCompletedCheckpointStore =
+			spy(new StandaloneCompletedCheckpointStore(1));
+
+		CompletedCheckpoint completedCheckpoint = new CompletedCheckpoint(
+			jobID,
+			2,
+			System.currentTimeMillis(),
+			System.currentTimeMillis() + 3000,
+			operatorStates,
+			Collections.<MasterState>emptyList(),
+			CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
+			new TestCompletedCheckpointStorageLocation());
+
+		when(standaloneCompletedCheckpointStore.getLatestCheckpoint(false)).thenReturn(completedCheckpoint);
+
+		// set up the coordinator
+		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+			true,
+			false,
+			0);
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			new JobID(),
+			chkConfig,
+			newJobVertex1.getTaskVertices(),
+			newJobVertex1.getTaskVertices(),
+			newJobVertex1.getTaskVertices(),
+			new StandaloneCheckpointIDCounter(),
+			standaloneCompletedCheckpointStore,
+			new MemoryStateBackend(),
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY,
+			failureManager);
+
+		coord.restoreLatestCheckpointedState(tasks, false, true);
+
+		for (int i = 0; i < newJobVertex1.getParallelism(); i++) {
+
+			final List<OperatorID> operatorIds = newJobVertex1.getOperatorIDs();
+
+			JobManagerTaskRestore taskRestore = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
+			Assert.assertEquals(2L, taskRestore.getRestoreCheckpointId());
+			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
+
+			OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
+			assertTrue(headOpState.getManagedKeyedState().isEmpty());
+			assertTrue(headOpState.getRawKeyedState().isEmpty());
+
+			// operator5
+			{
+				int operatorIndexInChain = 2;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				assertTrue(opState.getManagedOperatorState().isEmpty());
+				assertTrue(opState.getRawOperatorState().isEmpty());
+			}
+			// operator1
+			{
+				int operatorIndexInChain = 1;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
+					id1.f0, i, 2, 8, false);
+				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
+					id1.f0, i, 2, 8, true);
+
+				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
+				assertEquals(1, managedOperatorState.size());
+				assertTrue(CommonTestUtils.isStreamContentEqual(expectedManagedOpState.openInputStream(),
+					managedOperatorState.iterator().next().openInputStream()));
+
+				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
+				assertEquals(1, rawOperatorState.size());
+				assertTrue(CommonTestUtils.isStreamContentEqual(expectedRawOpState.openInputStream(),
+					rawOperatorState.iterator().next().openInputStream()));
+			}
+			// operator2
+			{
+				int operatorIndexInChain = 0;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
+					id2.f0, i, 2, 8, false);
+				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
+					id2.f0, i, 2, 8, true);
+
+				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
+				assertEquals(1, managedOperatorState.size());
+				assertTrue(CommonTestUtils.isStreamContentEqual(expectedManagedOpState.openInputStream(),
+					managedOperatorState.iterator().next().openInputStream()));
+
+				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
+				assertEquals(1, rawOperatorState.size());
+				assertTrue(CommonTestUtils.isStreamContentEqual(expectedRawOpState.openInputStream(),
+					rawOperatorState.iterator().next().openInputStream()));
+			}
+		}
+
+		List<List<Collection<OperatorStateHandle>>> actualManagedOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
+		List<List<Collection<OperatorStateHandle>>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
+
+		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
+
+			final List<OperatorID> operatorIds = newJobVertex2.getOperatorIDs();
+
+			JobManagerTaskRestore taskRestore = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
+			Assert.assertEquals(2L, taskRestore.getRestoreCheckpointId());
+			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
+
+			// operator 3
+			{
+				int operatorIndexInChain = 1;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = new ArrayList<>(1);
+				actualSubManagedOperatorState.add(opState.getManagedOperatorState());
+
+				List<Collection<OperatorStateHandle>> actualSubRawOperatorState = new ArrayList<>(1);
+				actualSubRawOperatorState.add(opState.getRawOperatorState());
+
+				actualManagedOperatorStates.add(actualSubManagedOperatorState);
+				actualRawOperatorStates.add(actualSubRawOperatorState);
+			}
+
+			// operator 6
+			{
+				int operatorIndexInChain = 0;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+				assertTrue(opState.getManagedOperatorState().isEmpty());
+				assertTrue(opState.getRawOperatorState().isEmpty());
+
+			}
+
+			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false);
+			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true);
+
+			OperatorSubtaskState headOpState =
+				stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
+
+			Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
+			Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
+
+			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
+			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
+		}
+
+		comparePartitionableState(expectedManagedOperatorStates.get(0), actualManagedOperatorStates);
+		comparePartitionableState(expectedRawOperatorStates.get(0), actualRawOperatorStates);
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 034fac4..2d86a06 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -33,11 +33,9 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.OperatorStreamStateHandle;
@@ -49,25 +47,17 @@ import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
-import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.util.ExceptionUtils;
-import org.apache.flink.util.SerializableObject;
 import org.apache.flink.util.TestLogger;
 
 import org.apache.flink.shaded.guava18.com.google.common.collect.Iterables;
 
-import org.hamcrest.BaseMatcher;
-import org.hamcrest.Description;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
-import org.mockito.Mockito;
-import org.mockito.hamcrest.MockitoHamcrest;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 import org.mockito.verification.VerificationMode;
 
 import java.io.IOException;
@@ -79,26 +69,15 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Objects;
 import java.util.Random;
 import java.util.UUID;
-import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.compareKeyedState;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.comparePartitionableState;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateChainedPartitionableStateHandle;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateKeyGroupState;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecution;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
 import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
@@ -1306,194 +1285,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	}
 
 	@Test
-	public void testPeriodicTriggering() {
-		try {
-			final JobID jid = new JobID();
-			final long start = System.currentTimeMillis();
-
-			// create some mock execution vertices and trigger some checkpoint
-
-			final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
-			final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID();
-			final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
-
-			ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
-			ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID);
-			ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
-
-			final AtomicInteger numCalls = new AtomicInteger();
-
-			final Execution execution = triggerVertex.getCurrentExecutionAttempt();
-
-			doAnswer(new Answer<Void>() {
-
-				private long lastId = -1;
-				private long lastTs = -1;
-
-				@Override
-				public Void answer(InvocationOnMock invocation) throws Throwable {
-					long id = (Long) invocation.getArguments()[0];
-					long ts = (Long) invocation.getArguments()[1];
-
-					assertTrue(id > lastId);
-					assertTrue(ts >= lastTs);
-					assertTrue(ts >= start);
-
-					lastId = id;
-					lastTs = ts;
-					numCalls.incrementAndGet();
-					return null;
-				}
-			}).when(execution).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
-
-			CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-				10,        // periodic interval is 10 ms
-				200000,    // timeout is very long (200 s)
-				0,
-				Integer.MAX_VALUE,
-				CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-				true,
-				false,
-				0);
-			CheckpointCoordinator coord = new CheckpointCoordinator(
-				jid,
-				chkConfig,
-				new ExecutionVertex[] { triggerVertex },
-				new ExecutionVertex[] { ackVertex },
-				new ExecutionVertex[] { commitVertex },
-				new StandaloneCheckpointIDCounter(),
-				new StandaloneCompletedCheckpointStore(2),
-				new MemoryStateBackend(),
-				Executors.directExecutor(),
-				SharedStateRegistry.DEFAULT_FACTORY,
-				failureManager);
-
-			coord.startCheckpointScheduler();
-
-			long timeout = System.currentTimeMillis() + 60000;
-			do {
-				Thread.sleep(20);
-			}
-			while (timeout > System.currentTimeMillis() && numCalls.get() < 5);
-			assertTrue(numCalls.get() >= 5);
-
-			coord.stopCheckpointScheduler();
-
-			// for 400 ms, no further calls may come.
-			// there may be the case that one trigger was fired and about to
-			// acquire the lock, such that after cancelling it will still do
-			// the remainder of its work
-			int numCallsSoFar = numCalls.get();
-			Thread.sleep(400);
-			assertTrue(numCallsSoFar == numCalls.get() ||
-					numCallsSoFar + 1 == numCalls.get());
-
-			// start another sequence of periodic scheduling
-			numCalls.set(0);
-			coord.startCheckpointScheduler();
-
-			timeout = System.currentTimeMillis() + 60000;
-			do {
-				Thread.sleep(20);
-			}
-			while (timeout > System.currentTimeMillis() && numCalls.get() < 5);
-			assertTrue(numCalls.get() >= 5);
-
-			coord.stopCheckpointScheduler();
-
-			// for 400 ms, no further calls may come
-			// there may be the case that one trigger was fired and about to
-			// acquire the lock, such that after cancelling it will still do
-			// the remainder of its work
-			numCallsSoFar = numCalls.get();
-			Thread.sleep(400);
-			assertTrue(numCallsSoFar == numCalls.get() ||
-					numCallsSoFar + 1 == numCalls.get());
-
-			coord.shutdown(JobStatus.FINISHED);
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
-	}
-
-	/**
-	 * This test verified that after a completed checkpoint a certain time has passed before
-	 * another is triggered.
-	 */
-	@Test
-	public void testMinTimeBetweenCheckpointsInterval() throws Exception {
-		final JobID jid = new JobID();
-
-		// create some mock execution vertices and trigger some checkpoint
-		final ExecutionAttemptID attemptID = new ExecutionAttemptID();
-		final ExecutionVertex vertex = mockExecutionVertex(attemptID);
-		final Execution executionAttempt = vertex.getCurrentExecutionAttempt();
-
-		final BlockingQueue<Long> triggerCalls = new LinkedBlockingQueue<>();
-
-		doAnswer(invocation -> {
-			triggerCalls.add((Long) invocation.getArguments()[0]);
-			return null;
-		}).when(executionAttempt).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
-
-		final long delay = 50;
-
-		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			12,           // periodic interval is 12 ms
-			200_000,     // timeout is very long (200 s)
-			delay,       // 50 ms delay between checkpoints
-			1,
-			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-			true,
-			false,
-			0);
-		final CheckpointCoordinator coord = new CheckpointCoordinator(
-				jid,
-				chkConfig,
-				new ExecutionVertex[] { vertex },
-				new ExecutionVertex[] { vertex },
-				new ExecutionVertex[] { vertex },
-				new StandaloneCheckpointIDCounter(),
-				new StandaloneCompletedCheckpointStore(2),
-				new MemoryStateBackend(),
-				Executors.directExecutor(),
-				SharedStateRegistry.DEFAULT_FACTORY,
-				failureManager);
-
-		try {
-			coord.startCheckpointScheduler();
-
-			// wait until the first checkpoint was triggered
-			Long firstCallId = triggerCalls.take();
-			assertEquals(1L, firstCallId.longValue());
-
-			AcknowledgeCheckpoint ackMsg = new AcknowledgeCheckpoint(jid, attemptID, 1L);
-
-			// tell the coordinator that the checkpoint is done
-			final long ackTime = System.nanoTime();
-			coord.receiveAcknowledgeMessage(ackMsg, TASK_MANAGER_LOCATION_INFO);
-
-			// wait until the next checkpoint is triggered
-			Long nextCallId = triggerCalls.take();
-			final long nextCheckpointTime = System.nanoTime();
-			assertEquals(2L, nextCallId.longValue());
-
-			final long delayMillis = (nextCheckpointTime - ackTime) / 1_000_000;
-
-			// we need to add one ms here to account for rounding errors
-			if (delayMillis + 1 < delay) {
-				fail("checkpoint came too early: delay was " + delayMillis + " but should have been at least " + delay);
-			}
-		}
-		finally {
-			coord.stopCheckpointScheduler();
-			coord.shutdown(JobStatus.FINISHED);
-		}
-	}
-
-	@Test
 	public void testMaxConcurrentAttempts1() {
 		testMaxConcurrentAttempts(1);
 	}
@@ -2072,961 +1863,67 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	}
 
 	/**
-	 * Tests that the checkpointed partitioned and non-partitioned state is assigned properly to
-	 * the {@link Execution} upon recovery.
-	 *
-	 * @throws Exception
+	 * Tests that the externalized checkpoint configuration is respected.
 	 */
 	@Test
-	public void testRestoreLatestCheckpointedState() throws Exception {
-		final JobID jid = new JobID();
-		final long timestamp = System.currentTimeMillis();
-
-		final JobVertexID jobVertexID1 = new JobVertexID();
-		final JobVertexID jobVertexID2 = new JobVertexID();
-		int parallelism1 = 3;
-		int parallelism2 = 2;
-		int maxParallelism1 = 42;
-		int maxParallelism2 = 13;
-
-		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
-			jobVertexID1,
-			parallelism1,
-			maxParallelism1);
-		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
-			jobVertexID2,
-			parallelism2,
-			maxParallelism2);
-
-		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
-
-		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
-		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
-
-		ExecutionVertex[] arrayExecutionVertices =
-				allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
-
-		CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore();
-
-		// set up the coordinator and validate the initial state
-		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			600000,
-			600000,
-			0,
-			Integer.MAX_VALUE,
-			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-			true,
-			false,
-			0);
-		CheckpointCoordinator coord = new CheckpointCoordinator(
-			jid,
-			chkConfig,
-			arrayExecutionVertices,
-			arrayExecutionVertices,
-			arrayExecutionVertices,
-			new StandaloneCheckpointIDCounter(),
-			store,
-			new MemoryStateBackend(),
-			Executors.directExecutor(),
-			SharedStateRegistry.DEFAULT_FACTORY,
-			failureManager);
-
-		// trigger the checkpoint
-		coord.triggerCheckpoint(timestamp, false);
-
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
-		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
-
-		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
+	public void testExternalizedCheckpoints() throws Exception {
+		try {
+			final JobID jid = new JobID();
+			final long timestamp = System.currentTimeMillis();
 
-		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
+			// create some mock Execution vertices that receive the checkpoint trigger messages
+			final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
+			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 
-			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-					jid,
-					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-					checkpointId,
-					new CheckpointMetrics(),
-					subtaskState);
+			// set up the coordinator and validate the initial state
+			CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+				600000,
+				600000,
+				0,
+				Integer.MAX_VALUE,
+				CheckpointRetentionPolicy.RETAIN_ON_FAILURE,
+				true,
+				false,
+				0);
+			CheckpointCoordinator coord = new CheckpointCoordinator(
+				jid,
+				chkConfig,
+				new ExecutionVertex[] { vertex1 },
+				new ExecutionVertex[] { vertex1 },
+				new ExecutionVertex[] { vertex1 },
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(1),
+				new MemoryStateBackend(),
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY,
+				failureManager);
 
-			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
-		}
+			assertTrue(coord.triggerCheckpoint(timestamp, false));
 
-		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+			for (PendingCheckpoint checkpoint : coord.getPendingCheckpoints().values()) {
+				CheckpointProperties props = checkpoint.getProps();
+				CheckpointProperties expected = CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.RETAIN_ON_FAILURE);
 
-			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-					jid,
-					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-					checkpointId,
-					new CheckpointMetrics(),
-					subtaskState);
+				assertEquals(expected, props);
+			}
 
-			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
+			// the now we should have a completed checkpoint
+			coord.shutdown(JobStatus.FINISHED);
 		}
-
-		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
-
-		assertEquals(1, completedCheckpoints.size());
-
-		// shutdown the store
-		store.shutdown(JobStatus.SUSPENDED);
-
-		// restore the store
-		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
-
-		tasks.put(jobVertexID1, jobVertex1);
-		tasks.put(jobVertexID2, jobVertex2);
-
-		coord.restoreLatestCheckpointedState(tasks, true, false);
-
-		// validate that all shared states are registered again after the recovery.
-		for (CompletedCheckpoint completedCheckpoint : completedCheckpoints) {
-			for (OperatorState taskState : completedCheckpoint.getOperatorStates().values()) {
-				for (OperatorSubtaskState subtaskState : taskState.getStates()) {
-					verify(subtaskState, times(2)).registerSharedStates(any(SharedStateRegistry.class));
-				}
-			}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
 		}
-
-		// verify the restored state
-		verifyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1);
-		verifyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2);
 	}
 
-	/**
-	 * Tests that the checkpoint restoration fails if the max parallelism of the job vertices has
-	 * changed.
-	 *
-	 * @throws Exception
-	 */
-	@Test(expected = IllegalStateException.class)
-	public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
-		final JobID jid = new JobID();
-		final long timestamp = System.currentTimeMillis();
-
-		final JobVertexID jobVertexID1 = new JobVertexID();
-		final JobVertexID jobVertexID2 = new JobVertexID();
-		int parallelism1 = 3;
-		int parallelism2 = 2;
-		int maxParallelism1 = 42;
-		int maxParallelism2 = 13;
-
-		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
-			jobVertexID1,
-			parallelism1,
-			maxParallelism1);
-		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
-			jobVertexID2,
-			parallelism2,
-			maxParallelism2);
-
-		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
-
-		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
-		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
-
-		ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
-
-		// set up the coordinator and validate the initial state
-		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			600000,
-			600000,
-			0,
-			Integer.MAX_VALUE,
-			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-			true,
-			false,
-			0);
-		CheckpointCoordinator coord = new CheckpointCoordinator(
-			jid,
-			chkConfig,
-			arrayExecutionVertices,
-			arrayExecutionVertices,
-			arrayExecutionVertices,
-			new StandaloneCheckpointIDCounter(),
-			new StandaloneCompletedCheckpointStore(1),
-			new MemoryStateBackend(),
-			Executors.directExecutor(),
-			SharedStateRegistry.DEFAULT_FACTORY,
-			failureManager);
-
-		// trigger the checkpoint
-		coord.triggerCheckpoint(timestamp, false);
-
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
-		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
-
-		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
-
-		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
-			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
-			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
-			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
-			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-					jid,
-					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-					checkpointId,
-					new CheckpointMetrics(),
-				taskOperatorSubtaskStates);
-
-			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
-		}
-
-		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
-			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
-			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
-			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-					jid,
-					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-					checkpointId,
-					new CheckpointMetrics(),
-					taskOperatorSubtaskStates);
-
-			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
-		}
-
-		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
-
-		assertEquals(1, completedCheckpoints.size());
-
-		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
-
-		int newMaxParallelism1 = 20;
-		int newMaxParallelism2 = 42;
-
-		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
-			jobVertexID1,
-			parallelism1,
-			newMaxParallelism1);
-
-		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
-			jobVertexID2,
-			parallelism2,
-			newMaxParallelism2);
-
-		tasks.put(jobVertexID1, newJobVertex1);
-		tasks.put(jobVertexID2, newJobVertex2);
-
-		coord.restoreLatestCheckpointedState(tasks, true, false);
-
-		fail("The restoration should have failed because the max parallelism changed.");
-	}
-
-	@Test
-	public void testRestoreLatestCheckpointedStateScaleIn() throws Exception {
-		testRestoreLatestCheckpointedStateWithChangingParallelism(false);
-	}
-
-	@Test
-	public void testRestoreLatestCheckpointedStateScaleOut() throws Exception {
-		testRestoreLatestCheckpointedStateWithChangingParallelism(true);
-	}
-
-	@Test
-	public void testRestoreLatestCheckpointWhenPreferCheckpoint() throws Exception {
-		testRestoreLatestCheckpointIsPreferSavepoint(true);
-	}
-
-	@Test
-	public void testRestoreLatestCheckpointWhenPreferSavepoint() throws Exception {
-		testRestoreLatestCheckpointIsPreferSavepoint(false);
-	}
-
-	private void testRestoreLatestCheckpointIsPreferSavepoint(boolean isPreferCheckpoint) {
-		try {
-			final JobID jid = new JobID();
-			long timestamp = System.currentTimeMillis();
-			StandaloneCheckpointIDCounter checkpointIDCounter = new StandaloneCheckpointIDCounter();
-
-			final JobVertexID statefulId = new JobVertexID();
-			final JobVertexID statelessId = new JobVertexID();
-
-			Execution statefulExec1 = mockExecution();
-			Execution statelessExec1 = mockExecution();
-
-			ExecutionVertex stateful1 = mockExecutionVertex(statefulExec1, statefulId, 0, 1);
-			ExecutionVertex stateless1 = mockExecutionVertex(statelessExec1, statelessId, 0, 1);
-
-			ExecutionJobVertex stateful = mockExecutionJobVertex(statefulId,
-				new ExecutionVertex[] { stateful1 });
-			ExecutionJobVertex stateless = mockExecutionJobVertex(statelessId,
-				new ExecutionVertex[] { stateless1 });
-
-			Map<JobVertexID, ExecutionJobVertex> map = new HashMap<JobVertexID, ExecutionJobVertex>();
-			map.put(statefulId, stateful);
-			map.put(statelessId, stateless);
-
-			CompletedCheckpointStore store = new RecoverableCompletedCheckpointStore(2);
-
-			CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-				600000,
-				600000,
-				0,
-				Integer.MAX_VALUE,
-				CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-				true,
-				isPreferCheckpoint,
-				0);
-			CheckpointCoordinator coord = new CheckpointCoordinator(
-				jid,
-				chkConfig,
-				new ExecutionVertex[] { stateful1, stateless1 },
-				new ExecutionVertex[] { stateful1, stateless1 },
-				new ExecutionVertex[] { stateful1, stateless1 },
-				checkpointIDCounter,
-				store,
-				new MemoryStateBackend(),
-				Executors.directExecutor(),
-				SharedStateRegistry.DEFAULT_FACTORY,
-				failureManager);
-
-			//trigger a checkpoint and wait to become a completed checkpoint
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
-
-			long checkpointId = checkpointIDCounter.getLast();
-
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
-			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			KeyedStateHandle serializedKeyGroupStates = generateKeyGroupState(keyGroupRange, testStates);
-
-			TaskStateSnapshot subtaskStatesForCheckpoint = new TaskStateSnapshot();
-
-			subtaskStatesForCheckpoint.putSubtaskStateByOperatorID(
-				OperatorID.fromJobVertexID(statefulId),
-				new OperatorSubtaskState(
-					StateObjectCollection.empty(),
-					StateObjectCollection.empty(),
-					StateObjectCollection.singleton(serializedKeyGroupStates),
-					StateObjectCollection.empty()));
-
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStatesForCheckpoint), TASK_MANAGER_LOCATION_INFO);
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId), TASK_MANAGER_LOCATION_INFO);
-
-			CompletedCheckpoint success = coord.getSuccessfulCheckpoints().get(0);
-			assertEquals(jid, success.getJobId());
-
-			// trigger a savepoint and wait it to be finished
-			String savepointDir = tmpFolder.newFolder().getAbsolutePath();
-			timestamp = System.currentTimeMillis();
-			CompletableFuture<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir);
-
-			KeyGroupRange keyGroupRangeForSavepoint = KeyGroupRange.of(1, 1);
-			List<SerializableObject> testStatesForSavepoint = Collections.singletonList(new SerializableObject());
-			KeyedStateHandle serializedKeyGroupStatesForSavepoint = generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
-
-			TaskStateSnapshot subtaskStatesForSavepoint = new TaskStateSnapshot();
-
-			subtaskStatesForSavepoint.putSubtaskStateByOperatorID(
-				OperatorID.fromJobVertexID(statefulId),
-				new OperatorSubtaskState(
-					StateObjectCollection.empty(),
-					StateObjectCollection.empty(),
-					StateObjectCollection.singleton(serializedKeyGroupStatesForSavepoint),
-					StateObjectCollection.empty()));
-
-			checkpointId = checkpointIDCounter.getLast();
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStatesForSavepoint), TASK_MANAGER_LOCATION_INFO);
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId), TASK_MANAGER_LOCATION_INFO);
-
-			assertTrue(savepointFuture.isDone());
-
-			//restore and jump the latest savepoint
-			coord.restoreLatestCheckpointedState(map, true, false);
-
-			//compare and see if it used the checkpoint's subtaskStates
-			BaseMatcher<JobManagerTaskRestore> matcher = new BaseMatcher<JobManagerTaskRestore>() {
-				@Override
-				public boolean matches(Object o) {
-					if (o instanceof JobManagerTaskRestore) {
-						JobManagerTaskRestore taskRestore = (JobManagerTaskRestore) o;
-						if (isPreferCheckpoint) {
-							return Objects.equals(taskRestore.getTaskStateSnapshot(), subtaskStatesForCheckpoint);
-						} else {
-							return Objects.equals(taskRestore.getTaskStateSnapshot(), subtaskStatesForSavepoint);
-						}
-					}
-					return false;
-				}
-
-				@Override
-				public void describeTo(Description description) {
-					if (isPreferCheckpoint) {
-						description.appendValue(subtaskStatesForCheckpoint);
-					} else {
-						description.appendValue(subtaskStatesForSavepoint);
-					}
-				}
-			};
-
-			verify(statefulExec1, times(1)).setInitialState(MockitoHamcrest.argThat(matcher));
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<JobManagerTaskRestore>any());
-
-			coord.shutdown(JobStatus.FINISHED);
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
-	}
-
-	@Test
-	public void testStateRecoveryWhenTopologyChangeOut() throws Exception {
-		testStateRecoveryWithTopologyChange(0);
-	}
-
-	@Test
-	public void testStateRecoveryWhenTopologyChangeIn() throws Exception {
-		testStateRecoveryWithTopologyChange(1);
-	}
-
-	@Test
-	public void testStateRecoveryWhenTopologyChange() throws Exception {
-		testStateRecoveryWithTopologyChange(2);
-	}
-
-
-	/**
-	 * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned
-	 * state.
-	 *
-	 * @throws Exception
-	 */
-	private void testRestoreLatestCheckpointedStateWithChangingParallelism(boolean scaleOut) throws Exception {
-		final JobID jid = new JobID();
-		final long timestamp = System.currentTimeMillis();
-
-		final JobVertexID jobVertexID1 = new JobVertexID();
-		final JobVertexID jobVertexID2 = new JobVertexID();
-		int parallelism1 = 3;
-		int parallelism2 = scaleOut ? 2 : 13;
-
-		int maxParallelism1 = 42;
-		int maxParallelism2 = 13;
-
-		int newParallelism2 = scaleOut ? 13 : 2;
-
-		final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex(
-				jobVertexID1,
-				parallelism1,
-				maxParallelism1);
-		final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex(
-				jobVertexID2,
-				parallelism2,
-				maxParallelism2);
-
-		List<ExecutionVertex> allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2);
-
-		allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices()));
-		allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices()));
-
-		ExecutionVertex[] arrayExecutionVertices =
-				allExecutionVertices.toArray(new ExecutionVertex[allExecutionVertices.size()]);
-
-		// set up the coordinator and validate the initial state
-		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			600000,
-			600000,
-			0,
-			Integer.MAX_VALUE,
-			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-			true,
-			false,
-			0);
-		CheckpointCoordinator coord = new CheckpointCoordinator(
-			jid,
-			chkConfig,
-			arrayExecutionVertices,
-			arrayExecutionVertices,
-			arrayExecutionVertices,
-			new StandaloneCheckpointIDCounter(),
-			new StandaloneCompletedCheckpointStore(1),
-			new MemoryStateBackend(),
-			Executors.directExecutor(),
-			SharedStateRegistry.DEFAULT_FACTORY,
-			failureManager);
-
-		// trigger the checkpoint
-		coord.triggerCheckpoint(timestamp, false);
-
-		assertTrue(coord.getPendingCheckpoints().keySet().size() == 1);
-		long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet());
-
-		List<KeyGroupRange> keyGroupPartitions1 =
-				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
-		List<KeyGroupRange> keyGroupPartitions2 =
-				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
-
-		//vertex 1
-		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false);
-			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
-			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
-			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(opStateBackend, null, keyedStateBackend, keyedStateRaw);
-			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
-			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
-
-			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-					jid,
-					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-					checkpointId,
-					new CheckpointMetrics(),
-					taskOperatorSubtaskStates);
-
-			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
-		}
-
-		//vertex 2
-		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesBackend = new ArrayList<>(jobVertex2.getParallelism());
-		final List<ChainedStateHandle<OperatorStateHandle>> expectedOpStatesRaw = new ArrayList<>(jobVertex2.getParallelism());
-		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
-			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false);
-			OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true);
-			expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend)));
-			expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw)));
-
-			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw);
-			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
-			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
-
-			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
-					jid,
-					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
-					checkpointId,
-					new CheckpointMetrics(),
-					taskOperatorSubtaskStates);
-
-			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
-		}
-
-		List<CompletedCheckpoint> completedCheckpoints = coord.getSuccessfulCheckpoints();
-
-		assertEquals(1, completedCheckpoints.size());
-
-		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
-
-		List<KeyGroupRange> newKeyGroupPartitions2 =
-				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
-
-		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
-				jobVertexID1,
-				parallelism1,
-				maxParallelism1);
-
-		// rescale vertex 2
-		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
-				jobVertexID2,
-				newParallelism2,
-				maxParallelism2);
-
-		tasks.put(jobVertexID1, newJobVertex1);
-		tasks.put(jobVertexID2, newJobVertex2);
-		coord.restoreLatestCheckpointedState(tasks, true, false);
-
-		// verify the restored state
-		verifyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1);
-		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
-		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
-		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-
-			List<OperatorID> operatorIDs = newJobVertex2.getOperatorIDs();
-
-			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
-			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
-
-			JobManagerTaskRestore taskRestore = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
-			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
-			TaskStateSnapshot taskStateHandles = taskRestore.getTaskStateSnapshot();
-
-			final int headOpIndex = operatorIDs.size() - 1;
-			List<Collection<OperatorStateHandle>> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size());
-			List<Collection<OperatorStateHandle>> allParallelRawOpStates = new ArrayList<>(operatorIDs.size());
-
-			for (int idx = 0; idx < operatorIDs.size(); ++idx) {
-				OperatorID operatorID = operatorIDs.get(idx);
-				OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID);
-				Collection<OperatorStateHandle> opStateBackend = opState.getManagedOperatorState();
-				Collection<OperatorStateHandle> opStateRaw = opState.getRawOperatorState();
-				allParallelManagedOpStates.add(opStateBackend);
-				allParallelRawOpStates.add(opStateRaw);
-				if (idx == headOpIndex) {
-					Collection<KeyedStateHandle> keyedStateBackend = opState.getManagedKeyedState();
-					Collection<KeyedStateHandle> keyGroupStateRaw = opState.getRawKeyedState();
-					compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
-					compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
-				}
-			}
-			actualOpStatesBackend.add(allParallelManagedOpStates);
-			actualOpStatesRaw.add(allParallelRawOpStates);
-		}
-
-		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
-		comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw);
-	}
-
-	private static Tuple2<JobVertexID, OperatorID> generateIDPair() {
-		JobVertexID jobVertexID = new JobVertexID();
-		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
-		return new Tuple2<>(jobVertexID, operatorID);
-	}
-
-	/**
-	 * <p>
-	 * old topology.
-	 * [operator1,operator2] * parallelism1 -> [operator3,operator4] * parallelism2
-	 * </p>
-	 *
-	 * <p>
-	 * new topology
-	 *
-	 * [operator5,operator1,operator3] * newParallelism1 -> [operator3, operator6] * newParallelism2
-	 * </p>
-	 * scaleType:
-	 * 0  increase parallelism
-	 * 1  decrease parallelism
-	 * 2  same parallelism
-	 */
-	public void testStateRecoveryWithTopologyChange(int scaleType) throws Exception {
-
-		/*
-		 * Old topology
-		 * CHAIN(op1 -> op2) * parallelism1 -> CHAIN(op3 -> op4) * parallelism2
-		 */
-		Tuple2<JobVertexID, OperatorID> id1 = generateIDPair();
-		Tuple2<JobVertexID, OperatorID> id2 = generateIDPair();
-		int parallelism1 = 10;
-		int maxParallelism1 = 64;
-
-		Tuple2<JobVertexID, OperatorID> id3 = generateIDPair();
-		Tuple2<JobVertexID, OperatorID> id4 = generateIDPair();
-		int parallelism2 = 10;
-		int maxParallelism2 = 64;
-
-		List<KeyGroupRange> keyGroupPartitions2 =
-			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
-
-		Map<OperatorID, OperatorState> operatorStates = new HashMap<>();
-
-		//prepare vertex1 state
-		for (Tuple2<JobVertexID, OperatorID> id : Arrays.asList(id1, id2)) {
-			OperatorState taskState = new OperatorState(id.f1, parallelism1, maxParallelism1);
-			operatorStates.put(id.f1, taskState);
-			for (int index = 0; index < taskState.getParallelism(); index++) {
-				OperatorStateHandle subManagedOperatorState =
-					generatePartitionableStateHandle(id.f0, index, 2, 8, false);
-				OperatorStateHandle subRawOperatorState =
-					generatePartitionableStateHandle(id.f0, index, 2, 8, true);
-				OperatorSubtaskState subtaskState = new OperatorSubtaskState(
-					subManagedOperatorState,
-					subRawOperatorState,
-					null,
-					null);
-				taskState.putState(index, subtaskState);
-			}
-		}
-
-		List<List<ChainedStateHandle<OperatorStateHandle>>> expectedManagedOperatorStates = new ArrayList<>();
-		List<List<ChainedStateHandle<OperatorStateHandle>>> expectedRawOperatorStates = new ArrayList<>();
-		//prepare vertex2 state
-		for (Tuple2<JobVertexID, OperatorID> id : Arrays.asList(id3, id4)) {
-			OperatorState operatorState = new OperatorState(id.f1, parallelism2, maxParallelism2);
-			operatorStates.put(id.f1, operatorState);
-			List<ChainedStateHandle<OperatorStateHandle>> expectedManagedOperatorState = new ArrayList<>();
-			List<ChainedStateHandle<OperatorStateHandle>> expectedRawOperatorState = new ArrayList<>();
-			expectedManagedOperatorStates.add(expectedManagedOperatorState);
-			expectedRawOperatorStates.add(expectedRawOperatorState);
-
-			for (int index = 0; index < operatorState.getParallelism(); index++) {
-				OperatorStateHandle subManagedOperatorState =
-					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
-						.get(0);
-				OperatorStateHandle subRawOperatorState =
-					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
-						.get(0);
-				KeyGroupsStateHandle subManagedKeyedState = id.f0.equals(id3.f0)
-					? generateKeyGroupState(id.f0, keyGroupPartitions2.get(index), false)
-					: null;
-				KeyGroupsStateHandle subRawKeyedState = id.f0.equals(id3.f0)
-					? generateKeyGroupState(id.f0, keyGroupPartitions2.get(index), true)
-					: null;
-
-				expectedManagedOperatorState.add(ChainedStateHandle.wrapSingleHandle(subManagedOperatorState));
-				expectedRawOperatorState.add(ChainedStateHandle.wrapSingleHandle(subRawOperatorState));
-
-				OperatorSubtaskState subtaskState = new OperatorSubtaskState(
-					subManagedOperatorState,
-					subRawOperatorState,
-					subManagedKeyedState,
-					subRawKeyedState);
-				operatorState.putState(index, subtaskState);
-			}
-		}
-
-		/*
-		 * New topology
-		 * CHAIN(op5 -> op1 -> op2) * newParallelism1 -> CHAIN(op3 -> op6) * newParallelism2
-		 */
-		Tuple2<JobVertexID, OperatorID> id5 = generateIDPair();
-		int newParallelism1 = 10;
-
-		Tuple2<JobVertexID, OperatorID> id6 = generateIDPair();
-		int newParallelism2 = parallelism2;
-
-		if (scaleType == 0) {
-			newParallelism2 = 20;
-		} else if (scaleType == 1) {
-			newParallelism2 = 8;
-		}
-
-		List<KeyGroupRange> newKeyGroupPartitions2 =
-			StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, newParallelism2);
-
-		final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex(
-			id5.f0,
-			Arrays.asList(id2.f1, id1.f1, id5.f1),
-			newParallelism1,
-			maxParallelism1);
-
-		final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex(
-			id3.f0,
-			Arrays.asList(id6.f1, id3.f1),
-			newParallelism2,
-			maxParallelism2);
-
-		Map<JobVertexID, ExecutionJobVertex> tasks = new HashMap<>();
-
-		tasks.put(id5.f0, newJobVertex1);
-		tasks.put(id3.f0, newJobVertex2);
-
-		JobID jobID = new JobID();
-		StandaloneCompletedCheckpointStore standaloneCompletedCheckpointStore =
-			spy(new StandaloneCompletedCheckpointStore(1));
-
-		CompletedCheckpoint completedCheckpoint = new CompletedCheckpoint(
-				jobID,
-				2,
-				System.currentTimeMillis(),
-				System.currentTimeMillis() + 3000,
-				operatorStates,
-				Collections.<MasterState>emptyList(),
-				CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-				new TestCompletedCheckpointStorageLocation());
-
-		when(standaloneCompletedCheckpointStore.getLatestCheckpoint(false)).thenReturn(completedCheckpoint);
-
-		// set up the coordinator and validate the initial state
-		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			600000,
-			600000,
-			0,
-			Integer.MAX_VALUE,
-			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-			true,
-			false,
-			0);
-		CheckpointCoordinator coord = new CheckpointCoordinator(
-			new JobID(),
-			chkConfig,
-			newJobVertex1.getTaskVertices(),
-			newJobVertex1.getTaskVertices(),
-			newJobVertex1.getTaskVertices(),
-			new StandaloneCheckpointIDCounter(),
-			standaloneCompletedCheckpointStore,
-			new MemoryStateBackend(),
-			Executors.directExecutor(),
-			SharedStateRegistry.DEFAULT_FACTORY,
-			failureManager);
-
-		coord.restoreLatestCheckpointedState(tasks, false, true);
-
-		for (int i = 0; i < newJobVertex1.getParallelism(); i++) {
-
-			final List<OperatorID> operatorIds = newJobVertex1.getOperatorIDs();
-
-			JobManagerTaskRestore taskRestore = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
-			Assert.assertEquals(2L, taskRestore.getRestoreCheckpointId());
-			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
-
-			OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
-			assertTrue(headOpState.getManagedKeyedState().isEmpty());
-			assertTrue(headOpState.getRawKeyedState().isEmpty());
-
-			// operator5
-			{
-				int operatorIndexInChain = 2;
-				OperatorSubtaskState opState =
-					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
-
-				assertTrue(opState.getManagedOperatorState().isEmpty());
-				assertTrue(opState.getRawOperatorState().isEmpty());
-			}
-			// operator1
-			{
-				int operatorIndexInChain = 1;
-				OperatorSubtaskState opState =
-					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
-
-				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
-					id1.f0, i, 2, 8, false);
-				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
-					id1.f0, i, 2, 8, true);
-
-				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
-				assertEquals(1, managedOperatorState.size());
-				assertTrue(CommonTestUtils.isStreamContentEqual(expectedManagedOpState.openInputStream(),
-					managedOperatorState.iterator().next().openInputStream()));
-
-				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
-				assertEquals(1, rawOperatorState.size());
-				assertTrue(CommonTestUtils.isStreamContentEqual(expectedRawOpState.openInputStream(),
-					rawOperatorState.iterator().next().openInputStream()));
-			}
-			// operator2
-			{
-				int operatorIndexInChain = 0;
-				OperatorSubtaskState opState =
-					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
-
-				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
-					id2.f0, i, 2, 8, false);
-				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
-					id2.f0, i, 2, 8, true);
-
-				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
-				assertEquals(1, managedOperatorState.size());
-				assertTrue(CommonTestUtils.isStreamContentEqual(expectedManagedOpState.openInputStream(),
-					managedOperatorState.iterator().next().openInputStream()));
-
-				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
-				assertEquals(1, rawOperatorState.size());
-				assertTrue(CommonTestUtils.isStreamContentEqual(expectedRawOpState.openInputStream(),
-					rawOperatorState.iterator().next().openInputStream()));
-			}
-		}
-
-		List<List<Collection<OperatorStateHandle>>> actualManagedOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
-		List<List<Collection<OperatorStateHandle>>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
-
-		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-
-			final List<OperatorID> operatorIds = newJobVertex2.getOperatorIDs();
-
-			JobManagerTaskRestore taskRestore = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
-			Assert.assertEquals(2L, taskRestore.getRestoreCheckpointId());
-			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
-
-			// operator 3
-			{
-				int operatorIndexInChain = 1;
-				OperatorSubtaskState opState =
-					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
-
-				List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = new ArrayList<>(1);
-				actualSubManagedOperatorState.add(opState.getManagedOperatorState());
-
-				List<Collection<OperatorStateHandle>> actualSubRawOperatorState = new ArrayList<>(1);
-				actualSubRawOperatorState.add(opState.getRawOperatorState());
-
-				actualManagedOperatorStates.add(actualSubManagedOperatorState);
-				actualRawOperatorStates.add(actualSubRawOperatorState);
-			}
-
-			// operator 6
-			{
-				int operatorIndexInChain = 0;
-				OperatorSubtaskState opState =
-					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
-				assertTrue(opState.getManagedOperatorState().isEmpty());
-				assertTrue(opState.getRawOperatorState().isEmpty());
-
-			}
-
-			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false);
-			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true);
-
-			OperatorSubtaskState headOpState =
-				stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
-
-			Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
-			Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
-
-			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
-			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
-		}
-
-		comparePartitionableState(expectedManagedOperatorStates.get(0), actualManagedOperatorStates);
-		comparePartitionableState(expectedRawOperatorStates.get(0), actualRawOperatorStates);
-	}
-
-	/**
-	 * Tests that the externalized checkpoint configuration is respected.
-	 */
-	@Test
-	public void testExternalizedCheckpoints() throws Exception {
-		try {
-			final JobID jid = new JobID();
-			final long timestamp = System.currentTimeMillis();
-
-			// create some mock Execution vertices that receive the checkpoint trigger messages
-			final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-
-			// set up the coordinator and validate the initial state
-			CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-				600000,
-				600000,
-				0,
-				Integer.MAX_VALUE,
-				CheckpointRetentionPolicy.RETAIN_ON_FAILURE,
-				true,
-				false,
-				0);
-			CheckpointCoordinator coord = new CheckpointCoordinator(
-				jid,
-				chkConfig,
-				new ExecutionVertex[] { vertex1 },
-				new ExecutionVertex[] { vertex1 },
-				new ExecutionVertex[] { vertex1 },
-				new StandaloneCheckpointIDCounter(),
-				new StandaloneCompletedCheckpointStore(1),
-				new MemoryStateBackend(),
-				Executors.directExecutor(),
-				SharedStateRegistry.DEFAULT_FACTORY,
-				failureManager);
-
-			assertTrue(coord.triggerCheckpoint(timestamp, false));
-
-			for (PendingCheckpoint checkpoint : coord.getPendingCheckpoints().values()) {
-				CheckpointProperties props = checkpoint.getProps();
-				CheckpointProperties expected = CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.RETAIN_ON_FAILURE);
-
-				assertEquals(expected, props);
-			}
-
-			// the now we should have a completed checkpoint
-			coord.shutdown(JobStatus.FINISHED);
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
-	}
-
-	@Test
-	public void testCreateKeyGroupPartitions() {
-		testCreateKeyGroupPartitions(1, 1);
-		testCreateKeyGroupPartitions(13, 1);
-		testCreateKeyGroupPartitions(13, 2);
-		testCreateKeyGroupPartitions(Short.MAX_VALUE, 1);
-		testCreateKeyGroupPartitions(Short.MAX_VALUE, 13);
-		testCreateKeyGroupPartitions(Short.MAX_VALUE, Short.MAX_VALUE);
+	@Test
+	public void testCreateKeyGroupPartitions() {
+		testCreateKeyGroupPartitions(1, 1);
+		testCreateKeyGroupPartitions(13, 1);
+		testCreateKeyGroupPartitions(13, 2);
+		testCreateKeyGroupPartitions(Short.MAX_VALUE, 1);
+		testCreateKeyGroupPartitions(Short.MAX_VALUE, 13);
+		testCreateKeyGroupPartitions(Short.MAX_VALUE, Short.MAX_VALUE);
 
 		Random r = new Random(1234);
 		for (int k = 0; k < 1000; ++k) {
@@ -3036,61 +1933,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		}
 	}
 
-	@Test
-	public void testStopPeriodicScheduler() throws Exception {
-		// create some mock Execution vertices that receive the checkpoint trigger messages
-		final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
-		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
-
-		// set up the coordinator and validate the initial state
-		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
-			600000,
-			600000,
-			0,
-			Integer.MAX_VALUE,
-			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
-			true,
-			false,
-			0);
-		CheckpointCoordinator coord = new CheckpointCoordinator(
-			new JobID(),
-			chkConfig,
-			new ExecutionVertex[] { vertex1 },
-			new ExecutionVertex[] { vertex1 },
-			new ExecutionVertex[] { vertex1 },
-			new StandaloneCheckpointIDCounter(),
-			new StandaloneCompletedCheckpointStore(1),
-			new MemoryStateBackend(),
-			Executors.directExecutor(),
-			SharedStateRegistry.DEFAULT_FACTORY,
-			failureManager);
-
-		// Periodic
-		try {
-			coord.triggerCheckpoint(
-					System.currentTimeMillis(),
-					CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-					null,
-					true,
-					false);
-			fail("The triggerCheckpoint call expected an exception");
-		} catch (CheckpointException e) {
-			assertEquals(CheckpointFailureReason.PERIODIC_SCHEDULER_SHUTDOWN, e.getCheckpointFailureReason());
-		}
-
-		// Not periodic
-		try {
-			coord.triggerCheckpoint(
-					System.currentTimeMillis(),
-					CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
-					null,
-					false,
-					false);
-		} catch (CheckpointException e) {
-			fail("Unexpected exception : " + e.getCheckpointFailureReason().message());
-		}
-	}
-
 	private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) {
 		List<KeyGroupRange> ranges = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism, parallelism);
 		for (int i = 0; i < maxParallelism; ++i) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
new file mode 100644
index 0000000..b224d9e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
@@ -0,0 +1,308 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
+import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyLong;
+import static org.mockito.Mockito.doAnswer;
+
+/**
+ * Tests for checkpoint coordinator triggering.
+ */
+public class CheckpointCoordinatorTriggeringTest extends TestLogger {
+	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
+
+	private CheckpointFailureManager failureManager;
+
+	@Before
+	public void setUp() throws Exception {
+		failureManager = new CheckpointFailureManager(
+			0,
+			NoOpFailJobCall.INSTANCE);
+	}
+
+	@Test
+	public void testPeriodicTriggering() {
+		try {
+			final JobID jid = new JobID();
+			final long start = System.currentTimeMillis();
+
+			// create some mock execution vertices and trigger some checkpoint
+
+			final ExecutionAttemptID triggerAttemptID = new ExecutionAttemptID();
+			final ExecutionAttemptID ackAttemptID = new ExecutionAttemptID();
+			final ExecutionAttemptID commitAttemptID = new ExecutionAttemptID();
+
+			ExecutionVertex triggerVertex = mockExecutionVertex(triggerAttemptID);
+			ExecutionVertex ackVertex = mockExecutionVertex(ackAttemptID);
+			ExecutionVertex commitVertex = mockExecutionVertex(commitAttemptID);
+
+			final AtomicInteger numCalls = new AtomicInteger();
+
+			final Execution execution = triggerVertex.getCurrentExecutionAttempt();
+
+			doAnswer(new Answer<Void>() {
+
+				private long lastId = -1;
+				private long lastTs = -1;
+
+				@Override
+				public Void answer(InvocationOnMock invocation) throws Throwable {
+					long id = (Long) invocation.getArguments()[0];
+					long ts = (Long) invocation.getArguments()[1];
+
+					assertTrue(id > lastId);
+					assertTrue(ts >= lastTs);
+					assertTrue(ts >= start);
+
+					lastId = id;
+					lastTs = ts;
+					numCalls.incrementAndGet();
+					return null;
+				}
+			}).when(execution).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
+
+			CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+				10,        // periodic interval is 10 ms
+				200000,    // timeout is very long (200 s)
+				0,
+				Integer.MAX_VALUE,
+				CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+				true,
+				false,
+				0);
+			CheckpointCoordinator coord = new CheckpointCoordinator(
+				jid,
+				chkConfig,
+				new ExecutionVertex[] { triggerVertex },
+				new ExecutionVertex[] { ackVertex },
+				new ExecutionVertex[] { commitVertex },
+				new StandaloneCheckpointIDCounter(),
+				new StandaloneCompletedCheckpointStore(2),
+				new MemoryStateBackend(),
+				Executors.directExecutor(),
+				SharedStateRegistry.DEFAULT_FACTORY,
+				failureManager);
+
+			coord.startCheckpointScheduler();
+
+			long timeout = System.currentTimeMillis() + 60000;
+			do {
+				Thread.sleep(20);
+			}
+			while (timeout > System.currentTimeMillis() && numCalls.get() < 5);
+			assertTrue(numCalls.get() >= 5);
+
+			coord.stopCheckpointScheduler();
+
+			// for 400 ms, no further calls may come.
+			// there may be the case that one trigger was fired and about to
+			// acquire the lock, such that after cancelling it will still do
+			// the remainder of its work
+			int numCallsSoFar = numCalls.get();
+			Thread.sleep(400);
+			assertTrue(numCallsSoFar == numCalls.get() ||
+				numCallsSoFar + 1 == numCalls.get());
+
+			// start another sequence of periodic scheduling
+			numCalls.set(0);
+			coord.startCheckpointScheduler();
+
+			timeout = System.currentTimeMillis() + 60000;
+			do {
+				Thread.sleep(20);
+			}
+			while (timeout > System.currentTimeMillis() && numCalls.get() < 5);
+			assertTrue(numCalls.get() >= 5);
+
+			coord.stopCheckpointScheduler();
+
+			// for 400 ms, no further calls may come
+			// there may be the case that one trigger was fired and about to
+			// acquire the lock, such that after cancelling it will still do
+			// the remainder of its work
+			numCallsSoFar = numCalls.get();
+			Thread.sleep(400);
+			assertTrue(numCallsSoFar == numCalls.get() ||
+				numCallsSoFar + 1 == numCalls.get());
+
+			coord.shutdown(JobStatus.FINISHED);
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	/**
+	 * This test verifies that, after a checkpoint completes, the configured minimum pause
+	 * elapses before the next checkpoint is triggered.
+	 */
+	@Test
+	public void testMinTimeBetweenCheckpointsInterval() throws Exception {
+		final JobID jid = new JobID();
+
+		// create some mock execution vertices and trigger some checkpoint
+		final ExecutionAttemptID attemptID = new ExecutionAttemptID();
+		final ExecutionVertex vertex = mockExecutionVertex(attemptID);
+		final Execution executionAttempt = vertex.getCurrentExecutionAttempt();
+
+		final BlockingQueue<Long> triggerCalls = new LinkedBlockingQueue<>();
+
+		doAnswer(invocation -> {
+			triggerCalls.add((Long) invocation.getArguments()[0]);
+			return null;
+		}).when(executionAttempt).triggerCheckpoint(anyLong(), anyLong(), any(CheckpointOptions.class));
+
+		final long delay = 50;
+
+		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+			12,           // periodic interval is 12 ms
+			200_000,     // timeout is very long (200 s)
+			delay,       // 50 ms delay between checkpoints
+			1,
+			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+			true,
+			false,
+			0);
+		final CheckpointCoordinator coord = new CheckpointCoordinator(
+			jid,
+			chkConfig,
+			new ExecutionVertex[] { vertex },
+			new ExecutionVertex[] { vertex },
+			new ExecutionVertex[] { vertex },
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(2),
+			new MemoryStateBackend(),
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY,
+			failureManager);
+
+		try {
+			coord.startCheckpointScheduler();
+
+			// wait until the first checkpoint was triggered
+			Long firstCallId = triggerCalls.take();
+			assertEquals(1L, firstCallId.longValue());
+
+			AcknowledgeCheckpoint ackMsg = new AcknowledgeCheckpoint(jid, attemptID, 1L);
+
+			// tell the coordinator that the checkpoint is done
+			final long ackTime = System.nanoTime();
+			coord.receiveAcknowledgeMessage(ackMsg, TASK_MANAGER_LOCATION_INFO);
+
+			// wait until the next checkpoint is triggered
+			Long nextCallId = triggerCalls.take();
+			final long nextCheckpointTime = System.nanoTime();
+			assertEquals(2L, nextCallId.longValue());
+
+			final long delayMillis = (nextCheckpointTime - ackTime) / 1_000_000;
+
+			// we need to add one ms here to account for rounding errors
+			if (delayMillis + 1 < delay) {
+				fail("checkpoint came too early: delay was " + delayMillis + " but should have been at least " + delay);
+			}
+		}
+		finally {
+			coord.stopCheckpointScheduler();
+			coord.shutdown(JobStatus.FINISHED);
+		}
+	}
+
+	@Test
+	public void testStopPeriodicScheduler() throws Exception {
+		// create some mock Execution vertices that receive the checkpoint trigger messages
+		final ExecutionAttemptID attemptID1 = new ExecutionAttemptID();
+		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
+
+		// set up the coordinator and validate the initial state
+		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
+			600000,
+			600000,
+			0,
+			Integer.MAX_VALUE,
+			CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION,
+			true,
+			false,
+			0);
+		CheckpointCoordinator coord = new CheckpointCoordinator(
+			new JobID(),
+			chkConfig,
+			new ExecutionVertex[] { vertex1 },
+			new ExecutionVertex[] { vertex1 },
+			new ExecutionVertex[] { vertex1 },
+			new StandaloneCheckpointIDCounter(),
+			new StandaloneCompletedCheckpointStore(1),
+			new MemoryStateBackend(),
+			Executors.directExecutor(),
+			SharedStateRegistry.DEFAULT_FACTORY,
+			failureManager);
+
+		// Periodic
+		try {
+			coord.triggerCheckpoint(
+				System.currentTimeMillis(),
+				CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
+				null,
+				true,
+				false);
+			fail("The triggerCheckpoint call expected an exception");
+		} catch (CheckpointException e) {
+			assertEquals(CheckpointFailureReason.PERIODIC_SCHEDULER_SHUTDOWN, e.getCheckpointFailureReason());
+		}
+
+		// Not periodic
+		try {
+			coord.triggerCheckpoint(
+				System.currentTimeMillis(),
+				CheckpointProperties.forCheckpoint(CheckpointRetentionPolicy.NEVER_RETAIN_AFTER_TERMINATION),
+				null,
+				false,
+				false);
+		} catch (CheckpointException e) {
+			fail("Unexpected exception : " + e.getCheckpointFailureReason().message());
+		}
+	}
+
+}


[flink] 01/08: [hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 126cff6ab4e433a69c774e126b2e85335e30ee7a
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Wed Sep 18 17:30:04 2019 +0800

    [hotfix] Extract utils of CheckpointCoordinatorTest into a separate utils class and correct codestyle
---
 .../CheckpointCoordinatorFailureTest.java          |   5 +-
 .../CheckpointCoordinatorMasterHooksTest.java      |   4 +-
 .../checkpoint/CheckpointCoordinatorTest.java      | 517 +++------------------
 .../CheckpointCoordinatorTestingUtils.java         | 472 +++++++++++++++++++
 .../checkpoint/CheckpointStateRestoreTest.java     |   4 +-
 .../runtime/messages/CheckpointMessagesTest.java   |  13 +-
 6 files changed, 542 insertions(+), 473 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index da181ba..b6b7930 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -49,6 +49,9 @@ import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+/**
+ * Tests for failure of checkpoint coordinator.
+ */
 @RunWith(PowerMockRunner.class)
 @PrepareForTest(PendingCheckpoint.class)
 public class CheckpointCoordinatorFailureTest extends TestLogger {
@@ -62,7 +65,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		JobID jid = new JobID();
 
 		final ExecutionAttemptID executionAttemptId = new ExecutionAttemptID();
-		final ExecutionVertex vertex = CheckpointCoordinatorTest.mockExecutionVertex(executionAttemptId);
+		final ExecutionVertex vertex = CheckpointCoordinatorTestingUtils.mockExecutionVertex(executionAttemptId);
 
 		final long triggerTimestamp = 1L;
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index 48b6583..8453cba 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -48,7 +48,7 @@ import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 
-import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -353,7 +353,7 @@ public class CheckpointCoordinatorMasterHooksTest {
 	// ------------------------------------------------------------------------
 
 	/**
-	 * This test makes sure that the checkpoint is already registered by the time
+	 * This test makes sure that the checkpoint is already registered by the time
 	 * that the hooks are called
 	 */
 	@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 98d80fd..034fac4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -19,9 +19,7 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.execution.ExecutionState;
@@ -39,7 +37,6 @@ import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
-import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
@@ -55,8 +52,6 @@ import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLo
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.util.ExceptionUtils;
-import org.apache.flink.util.InstantiationUtil;
-import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializableObject;
 import org.apache.flink.util.TestLogger;
 
@@ -76,7 +71,6 @@ import org.mockito.stubbing.Answer;
 import org.mockito.verification.VerificationMode;
 
 import java.io.IOException;
-import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -91,11 +85,20 @@ import java.util.UUID;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Executor;
 import java.util.concurrent.LinkedBlockingQueue;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.compareKeyedState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.comparePartitionableState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateChainedPartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generateKeyGroupState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecution;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionJobVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockExecutionVertex;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.mockSubtaskState;
+import static org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.verifyStateRestore;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotEquals;
@@ -439,7 +442,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
 
-
 			// decline checkpoint from the other task, this should cancel the checkpoint
 			// and trigger a new one
 			coord.receiveDeclineMessage(new DeclineCheckpoint(jid, attemptID1, checkpointId), TASK_MANAGER_LOCATION_INFO);
@@ -923,19 +925,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
 			OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
 
-			TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates11 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates12 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates13 = spy(new TaskStateSnapshot());
 
-			OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class);
-			taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1);
-			taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2);
-			taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3);
+			OperatorSubtaskState subtaskState11 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState12 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState13 = mock(OperatorSubtaskState.class);
+			taskOperatorSubtaskStates11.putSubtaskStateByOperatorID(opID1, subtaskState11);
+			taskOperatorSubtaskStates12.putSubtaskStateByOperatorID(opID2, subtaskState12);
+			taskOperatorSubtaskStates13.putSubtaskStateByOperatorID(opID3, subtaskState13);
 
 			// acknowledge one of the three tasks
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates12), TASK_MANAGER_LOCATION_INFO);
 
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
@@ -953,17 +955,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			}
 			long checkpointId2 = pending2.getCheckpointId();
 
-			TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot());
-			TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates21 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates22 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates23 = spy(new TaskStateSnapshot());
 
-			OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class);
-			OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState21 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState22 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState23 = mock(OperatorSubtaskState.class);
 
-			taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1);
-			taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2);
-			taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3);
+			taskOperatorSubtaskStates21.putSubtaskStateByOperatorID(opID1, subtaskState21);
+			taskOperatorSubtaskStates22.putSubtaskStateByOperatorID(opID2, subtaskState22);
+			taskOperatorSubtaskStates23.putSubtaskStateByOperatorID(opID3, subtaskState23);
 
 			// trigger messages should have been sent
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
@@ -972,13 +974,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// we acknowledge one more task from the first checkpoint and the second
 			// checkpoint completely. The second checkpoint should then subsume the first checkpoint
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates23), TASK_MANAGER_LOCATION_INFO);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates21), TASK_MANAGER_LOCATION_INFO);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates11), TASK_MANAGER_LOCATION_INFO);
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2), TASK_MANAGER_LOCATION_INFO);
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates22), TASK_MANAGER_LOCATION_INFO);
 
 			// now, the second checkpoint should be confirmed, and the first discarded
 			// actually both pending checkpoints are discarded, and the second has been transformed
@@ -990,13 +992,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// validate that all received subtask states in the first checkpoint have been discarded
-			verify(subtaskState1_1, times(1)).discardState();
-			verify(subtaskState1_2, times(1)).discardState();
+			verify(subtaskState11, times(1)).discardState();
+			verify(subtaskState12, times(1)).discardState();
 
 			// validate that all subtask states in the second checkpoint are not discarded
-			verify(subtaskState2_1, never()).discardState();
-			verify(subtaskState2_2, never()).discardState();
-			verify(subtaskState2_3, never()).discardState();
+			verify(subtaskState21, never()).discardState();
+			verify(subtaskState22, never()).discardState();
+			verify(subtaskState23, never()).discardState();
 
 			// validate the committed checkpoints
 			List<CompletedCheckpoint> scs = coord.getSuccessfulCheckpoints();
@@ -1010,15 +1012,15 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
 			// send the last remaining ack for the first checkpoint. This should not do anything
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3), TASK_MANAGER_LOCATION_INFO);
-			verify(subtaskState1_3, times(1)).discardState();
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates13), TASK_MANAGER_LOCATION_INFO);
+			verify(subtaskState13, times(1)).discardState();
 
 			coord.shutdown(JobStatus.FINISHED);
 
 			// validate that the states in the second checkpoint have been discarded
-			verify(subtaskState2_1, times(1)).discardState();
-			verify(subtaskState2_2, times(1)).discardState();
-			verify(subtaskState2_3, times(1)).discardState();
+			verify(subtaskState21, times(1)).discardState();
+			verify(subtaskState22, times(1)).discardState();
+			verify(subtaskState23, times(1)).discardState();
 
 		}
 		catch (Exception e) {
@@ -1377,7 +1379,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			coord.stopCheckpointScheduler();
 
-
 			// for 400 ms, no further calls may come.
 			// there may be the case that one trigger was fired and about to
 			// acquire the lock, such that after cancelling it will still do
@@ -1385,7 +1386,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			int numCallsSoFar = numCalls.get();
 			Thread.sleep(400);
 			assertTrue(numCallsSoFar == numCalls.get() ||
-					numCallsSoFar+1 == numCalls.get());
+					numCallsSoFar + 1 == numCalls.get());
 
 			// start another sequence of periodic scheduling
 			numCalls.set(0);
@@ -2200,7 +2201,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 	 *
 	 * @throws Exception
 	 */
-	@Test(expected=IllegalStateException.class)
+	@Test(expected = IllegalStateException.class)
 	public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception {
 		final JobID jid = new JobID();
 		final long timestamp = System.currentTimeMillis();
@@ -2275,7 +2276,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
 		}
 
-
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
 			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, null, keyGroupState, null);
@@ -2391,9 +2391,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			long checkpointId = checkpointIDCounter.getLast();
 
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			KeyedStateHandle serializedKeyGroupStates = generateKeyGroupState(keyGroupRange, testStates);
 
 			TaskStateSnapshot subtaskStatesForCheckpoint = new TaskStateSnapshot();
 
@@ -2416,10 +2416,9 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			timestamp = System.currentTimeMillis();
 			CompletableFuture<CompletedCheckpoint> savepointFuture = coord.triggerSavepoint(timestamp, savepointDir);
 
-
 			KeyGroupRange keyGroupRangeForSavepoint = KeyGroupRange.of(1, 1);
 			List<SerializableObject> testStatesForSavepoint = Collections.singletonList(new SerializableObject());
-			KeyedStateHandle serializedKeyGroupStatesForSavepoint = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
+			KeyedStateHandle serializedKeyGroupStatesForSavepoint = generateKeyGroupState(keyGroupRangeForSavepoint, testStatesForSavepoint);
 
 			TaskStateSnapshot subtaskStatesForSavepoint = new TaskStateSnapshot();
 
@@ -2677,16 +2676,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		OperatorID operatorID = OperatorID.fromJobVertexID(jobVertexID);
 		return new Tuple2<>(jobVertexID, operatorID);
 	}
-	
+
 	/**
-	 * old topology
+	 * <p>
+	 * old topology.
 	 * [operator1,operator2] * parallelism1 -> [operator3,operator4] * parallelism2
+	 * </p>
 	 *
-	 *
+	 * <p>
 	 * new topology
 	 *
 	 * [operator5,operator1,operator3] * newParallelism1 -> [operator3, operator6] * newParallelism2
-	 *
+	 * </p>
 	 * scaleType:
 	 * 0  increase parallelism
 	 * 1  decrease parallelism
@@ -2956,7 +2957,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
 			Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
 
-
 			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
 			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
 		}
@@ -3019,383 +3019,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		}
 	}
 
-	// ------------------------------------------------------------------------
-	//  Utilities
-	// ------------------------------------------------------------------------
-
-	public static KeyGroupsStateHandle generateKeyGroupState(
-			JobVertexID jobVertexID,
-			KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
-
-		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
-
-		// generate state for one keygroup
-		for (int keyGroupIndex : keyGroupPartition) {
-			int vertexHash = jobVertexID.hashCode();
-			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
-			Random random = new Random(seed);
-			int simulatedStateValue = random.nextInt();
-			testStatesLists.add(simulatedStateValue);
-		}
-
-		return generateKeyGroupState(keyGroupPartition, testStatesLists);
-	}
-
-	public static KeyGroupsStateHandle generateKeyGroupState(
-			KeyGroupRange keyGroupRange,
-			List<? extends Serializable> states) throws IOException {
-
-		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
-
-		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
-				serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
-
-		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
-
-		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
-				String.valueOf(UUID.randomUUID()),
-				serializedDataWithOffsets.f0);
-
-		return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
-	}
-
-	public static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
-			List<List<? extends Serializable>> serializables) throws IOException {
-
-		List<long[]> offsets = new ArrayList<>(serializables.size());
-		List<byte[]> serializedGroupValues = new ArrayList<>();
-
-		int runningGroupsOffset = 0;
-		for(List<? extends Serializable> list : serializables) {
-
-			long[] currentOffsets = new long[list.size()];
-			offsets.add(currentOffsets);
-
-			for (int i = 0; i < list.size(); ++i) {
-				currentOffsets[i] = runningGroupsOffset;
-				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
-				serializedGroupValues.add(serializedValue);
-				runningGroupsOffset += serializedValue.length;
-			}
-		}
-
-		//write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray
-		byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
-		runningGroupsOffset = 0;
-		for (byte[] serializedGroupValue : serializedGroupValues) {
-			System.arraycopy(
-					serializedGroupValue,
-					0,
-					allSerializedValuesConcatenated,
-					runningGroupsOffset,
-					serializedGroupValue.length);
-			runningGroupsOffset += serializedGroupValue.length;
-		}
-		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
-	}
-
-	public static OperatorStateHandle generatePartitionableStateHandle(
-		JobVertexID jobVertexID,
-		int index,
-		int namedStates,
-		int partitionsPerState,
-		boolean rawState) throws IOException {
-
-		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
-
-		for (int i = 0; i < namedStates; ++i) {
-			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
-			// generate state
-			int seed = jobVertexID.hashCode() * index + i * namedStates;
-			if (rawState) {
-				seed = (seed + 1) * 31;
-			}
-			Random random = new Random(seed);
-			for (int j = 0; j < partitionsPerState; ++j) {
-				int simulatedStateValue = random.nextInt();
-				testStatesLists.add(simulatedStateValue);
-			}
-			statesListsMap.put("state-" + i, testStatesLists);
-		}
-
-		return generatePartitionableStateHandle(statesListsMap);
-	}
-
-	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
-			JobVertexID jobVertexID,
-			int index,
-			int namedStates,
-			int partitionsPerState,
-			boolean rawState) throws IOException {
-
-		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
-
-		for (int i = 0; i < namedStates; ++i) {
-			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
-			// generate state
-			int seed = jobVertexID.hashCode() * index + i * namedStates;
-			if (rawState) {
-				seed = (seed + 1) * 31;
-			}
-			Random random = new Random(seed);
-			for (int j = 0; j < partitionsPerState; ++j) {
-				int simulatedStateValue = random.nextInt();
-				testStatesLists.add(simulatedStateValue);
-			}
-			statesListsMap.put("state-" + i, testStatesLists);
-		}
-
-		return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
-	}
-
-	private static OperatorStateHandle generatePartitionableStateHandle(
-		Map<String, List<? extends Serializable>> states) throws IOException {
-
-		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
-
-		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
-			namedStateSerializables.add(entry.getValue());
-		}
-
-		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
-
-		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
-
-		int idx = 0;
-		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
-			offsetsMap.put(
-				entry.getKey(),
-				new OperatorStateHandle.StateMetaInfo(
-					serializationWithOffsets.f1.get(idx),
-					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
-			++idx;
-		}
-
-		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
-			String.valueOf(UUID.randomUUID()),
-			serializationWithOffsets.f0);
-
-		return new OperatorStreamStateHandle(offsetsMap, streamStateHandle);
-	}
-
-	static ExecutionJobVertex mockExecutionJobVertex(
-			JobVertexID jobVertexID,
-			int parallelism,
-			int maxParallelism) {
-
-		return mockExecutionJobVertex(
-			jobVertexID,
-			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
-			parallelism,
-			maxParallelism
-		);
-	}
-
-	static ExecutionJobVertex mockExecutionJobVertex(
-		JobVertexID jobVertexID,
-		List<OperatorID> jobVertexIDs,
-		int parallelism,
-		int maxParallelism) {
-		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
-
-		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
-
-		for (int i = 0; i < parallelism; i++) {
-			executionVertices[i] = mockExecutionVertex(
-				new ExecutionAttemptID(),
-				jobVertexID,
-				jobVertexIDs,
-				parallelism,
-				maxParallelism,
-				ExecutionState.RUNNING);
-
-			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
-		}
-
-		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
-		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
-		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
-		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
-		when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
-		when(executionJobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
-		when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[jobVertexIDs.size()]));
-
-		return executionJobVertex;
-	}
-
-	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
-		JobVertexID jobVertexID = new JobVertexID();
-		return mockExecutionVertex(
-			attemptID,
-			jobVertexID,
-			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
-			1,
-			1,
-			ExecutionState.RUNNING);
-	}
-
-	private static ExecutionVertex mockExecutionVertex(
-		ExecutionAttemptID attemptID,
-		JobVertexID jobVertexID,
-		List<OperatorID> jobVertexIDs,
-		int parallelism,
-		int maxParallelism,
-		ExecutionState state,
-		ExecutionState ... successiveStates) {
-
-		ExecutionVertex vertex = mock(ExecutionVertex.class);
-
-		final Execution exec = spy(new Execution(
-			mock(Executor.class),
-			vertex,
-			1,
-			1L,
-			1L,
-			Time.milliseconds(500L)
-		));
-		when(exec.getAttemptId()).thenReturn(attemptID);
-		when(exec.getState()).thenReturn(state, successiveStates);
-
-		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
-		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
-		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
-		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
-
-		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
-		when(jobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
-		
-		when(vertex.getJobVertex()).thenReturn(jobVertex);
-
-		return vertex;
-	}
-
-	static TaskStateSnapshot mockSubtaskState(
-		JobVertexID jobVertexID,
-		int index,
-		KeyGroupRange keyGroupRange) throws IOException {
-
-		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
-		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
-
-		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
-		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
-			partitionableState, null, partitionedKeyGroupState, null)
-		);
-
-		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
-
-		return subtaskStates;
-	}
-
-	public static void verifyStateRestore(
-			JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
-			List<KeyGroupRange> keyGroupPartitions) throws Exception {
-
-		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
-
-			JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
-			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
-			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
-
-			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
-
-			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
-					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
-
-			assertTrue(CommonTestUtils.isStreamContentEqual(
-					expectedOpStateBackend.get(0).openInputStream(),
-					operatorState.getManagedOperatorState().iterator().next().openInputStream()));
-
-			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
-					jobVertexID, keyGroupPartitions.get(i), false);
-			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
-		}
-	}
-
-	public static void compareKeyedState(
-			Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
-			Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
-
-		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
-		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
-		int actualTotalKeyGroups = 0;
-		for(KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
-			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
-
-			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
-		}
-
-		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
-
-		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
-			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
-				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-				inputStream.seek(offset);
-				int expectedKeyGroupState =
-						InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
-				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
-
-					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
-
-					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
-					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
-						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
-						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
-							actualInputStream.seek(actualOffset);
-							int actualGroupState = InstantiationUtil.
-									deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
-							assertEquals(expectedKeyGroupState, actualGroupState);
-						}
-					}
-				}
-			}
-		}
-	}
-
-	public static void comparePartitionableState(
-			List<ChainedStateHandle<OperatorStateHandle>> expected,
-			List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
-
-		List<String> expectedResult = new ArrayList<>();
-		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
-			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
-				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
-				collectResult(i, operatorStateHandle, expectedResult);
-			}
-		}
-		Collections.sort(expectedResult);
-
-		List<String> actualResult = new ArrayList<>();
-		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
-			if (collectionList != null) {
-				for (int i = 0; i < collectionList.size(); ++i) {
-					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
-					Assert.assertNotNull(stateHandles);
-					for (OperatorStateHandle operatorStateHandle : stateHandles) {
-						collectResult(i, operatorStateHandle, actualResult);
-					}
-				}
-			}
-		}
-
-		Collections.sort(actualResult);
-		Assert.assertEquals(expectedResult, actualResult);
-	}
-
-	private static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
-		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
-			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
-				for (long offset : entry.getValue().getOffsets()) {
-					in.seek(offset);
-					Integer state = InstantiationUtil.
-							deserializeObject(in, Thread.currentThread().getContextClassLoader());
-					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
-				}
-			}
-		}
-	}
-
-
 	@Test
 	public void testCreateKeyGroupPartitions() {
 		testCreateKeyGroupPartitions(1, 1);
@@ -4071,36 +3694,4 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint, TASK_MANAGER_LOCATION_INFO);
 		}
 	}
-
-	private Execution mockExecution() {
-		Execution mock = mock(Execution.class);
-		when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
-		when(mock.getState()).thenReturn(ExecutionState.RUNNING);
-		return mock;
-	}
-
-	private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
-		ExecutionVertex mock = mock(ExecutionVertex.class);
-		when(mock.getJobvertexId()).thenReturn(vertexId);
-		when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
-		when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
-		when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
-		when(mock.getMaxParallelism()).thenReturn(parallelism);
-		return mock;
-	}
-
-	private ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
-		ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
-		when(vertex.getParallelism()).thenReturn(vertices.length);
-		when(vertex.getMaxParallelism()).thenReturn(vertices.length);
-		when(vertex.getJobVertexId()).thenReturn(id);
-		when(vertex.getTaskVertices()).thenReturn(vertices);
-		when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
-		when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
-
-		for (ExecutionVertex v : vertices) {
-			when(v.getJobVertex()).thenReturn(vertex);
-		}
-		return vertex;
-	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
new file mode 100644
index 0000000..9578ac7
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.api.common.time.Time;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.ChainedStateHandle;
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
+import org.apache.flink.runtime.state.KeyGroupsStateHandle;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.util.InstantiationUtil;
+import org.apache.flink.util.Preconditions;
+
+import org.junit.Assert;
+
+import java.io.IOException;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+import java.util.concurrent.Executor;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.when;
+
+/**
+ * Testing utils for checkpoint coordinator.
+ */
+public class CheckpointCoordinatorTestingUtils {
+
+	public static OperatorStateHandle generatePartitionableStateHandle(
+		JobVertexID jobVertexID,
+		int index,
+		int namedStates,
+		int partitionsPerState,
+		boolean rawState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return generatePartitionableStateHandle(statesListsMap);
+	}
+
+	static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
+		JobVertexID jobVertexID,
+		int index,
+		int namedStates,
+		int partitionsPerState,
+		boolean rawState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
+	}
+
+	static OperatorStateHandle generatePartitionableStateHandle(
+		Map<String, List<? extends Serializable>> states) throws IOException {
+
+		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
+
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			namedStateSerializables.add(entry.getValue());
+		}
+
+		Tuple2<byte[], List<long[]>> serializationWithOffsets = serializeTogetherAndTrackOffsets(namedStateSerializables);
+
+		Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>(states.size());
+
+		int idx = 0;
+		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
+			offsetsMap.put(
+				entry.getKey(),
+				new OperatorStateHandle.StateMetaInfo(
+					serializationWithOffsets.f1.get(idx),
+					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+			++idx;
+		}
+
+		ByteStreamStateHandle streamStateHandle = new ByteStreamStateHandle(
+			String.valueOf(UUID.randomUUID()),
+			serializationWithOffsets.f0);
+
+		return new OperatorStreamStateHandle(offsetsMap, streamStateHandle);
+	}
+
+	static Tuple2<byte[], List<long[]>> serializeTogetherAndTrackOffsets(
+		List<List<? extends Serializable>> serializables) throws IOException {
+
+		List<long[]> offsets = new ArrayList<>(serializables.size());
+		List<byte[]> serializedGroupValues = new ArrayList<>();
+
+		int runningGroupsOffset = 0;
+		for (List<? extends Serializable> list : serializables) {
+
+			long[] currentOffsets = new long[list.size()];
+			offsets.add(currentOffsets);
+
+			for (int i = 0; i < list.size(); ++i) {
+				currentOffsets[i] = runningGroupsOffset;
+				byte[] serializedValue = InstantiationUtil.serializeObject(list.get(i));
+				serializedGroupValues.add(serializedValue);
+				runningGroupsOffset += serializedValue.length;
+			}
+		}
+
+		// write all generated values into a single byte array, which is indexed by the offsets recorded above
+		byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset];
+		runningGroupsOffset = 0;
+		for (byte[] serializedGroupValue : serializedGroupValues) {
+			System.arraycopy(
+				serializedGroupValue,
+				0,
+				allSerializedValuesConcatenated,
+				runningGroupsOffset,
+				serializedGroupValue.length);
+			runningGroupsOffset += serializedGroupValue.length;
+		}
+		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
+	}
+
+	public static void verifyStateRestore(
+		JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex,
+		List<KeyGroupRange> keyGroupPartitions) throws Exception {
+
+		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
+
+			JobManagerTaskRestore taskRestore = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskRestore();
+			Assert.assertEquals(1L, taskRestore.getRestoreCheckpointId());
+			TaskStateSnapshot stateSnapshot = taskRestore.getTaskStateSnapshot();
+
+			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
+
+			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
+				generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
+
+			assertTrue(CommonTestUtils.isStreamContentEqual(
+				expectedOpStateBackend.get(0).openInputStream(),
+				operatorState.getManagedOperatorState().iterator().next().openInputStream()));
+
+			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
+				jobVertexID, keyGroupPartitions.get(i), false);
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
+		}
+	}
+
+	static void compareKeyedState(
+		Collection<KeyGroupsStateHandle> expectPartitionedKeyGroupState,
+		Collection<? extends KeyedStateHandle> actualPartitionedKeyGroupState) throws Exception {
+
+		KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.iterator().next();
+		int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
+		int actualTotalKeyGroups = 0;
+		for (KeyedStateHandle keyedStateHandle: actualPartitionedKeyGroupState) {
+			assertTrue(keyedStateHandle instanceof KeyGroupsStateHandle);
+
+			actualTotalKeyGroups += keyedStateHandle.getKeyGroupRange().getNumberOfKeyGroups();
+		}
+
+		assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups);
+
+		try (FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.openInputStream()) {
+			for (int groupId : expectedHeadOpKeyGroupStateHandle.getKeyGroupRange()) {
+				long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+				inputStream.seek(offset);
+				int expectedKeyGroupState =
+					InstantiationUtil.deserializeObject(inputStream, Thread.currentThread().getContextClassLoader());
+				for (KeyedStateHandle oneActualKeyedStateHandle : actualPartitionedKeyGroupState) {
+
+					assertTrue(oneActualKeyedStateHandle instanceof KeyGroupsStateHandle);
+
+					KeyGroupsStateHandle oneActualKeyGroupStateHandle = (KeyGroupsStateHandle) oneActualKeyedStateHandle;
+					if (oneActualKeyGroupStateHandle.getKeyGroupRange().contains(groupId)) {
+						long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId);
+						try (FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.openInputStream()) {
+							actualInputStream.seek(actualOffset);
+							int actualGroupState = InstantiationUtil.
+								deserializeObject(actualInputStream, Thread.currentThread().getContextClassLoader());
+							assertEquals(expectedKeyGroupState, actualGroupState);
+						}
+					}
+				}
+			}
+		}
+	}
+
+	static void comparePartitionableState(
+		List<ChainedStateHandle<OperatorStateHandle>> expected,
+		List<List<Collection<OperatorStateHandle>>> actual) throws Exception {
+
+		List<String> expectedResult = new ArrayList<>();
+		for (ChainedStateHandle<OperatorStateHandle> chainedStateHandle : expected) {
+			for (int i = 0; i < chainedStateHandle.getLength(); ++i) {
+				OperatorStateHandle operatorStateHandle = chainedStateHandle.get(i);
+				collectResult(i, operatorStateHandle, expectedResult);
+			}
+		}
+		Collections.sort(expectedResult);
+
+		List<String> actualResult = new ArrayList<>();
+		for (List<Collection<OperatorStateHandle>> collectionList : actual) {
+			if (collectionList != null) {
+				for (int i = 0; i < collectionList.size(); ++i) {
+					Collection<OperatorStateHandle> stateHandles = collectionList.get(i);
+					Assert.assertNotNull(stateHandles);
+					for (OperatorStateHandle operatorStateHandle : stateHandles) {
+						collectResult(i, operatorStateHandle, actualResult);
+					}
+				}
+			}
+		}
+
+		Collections.sort(actualResult);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	static void collectResult(int opIdx, OperatorStateHandle operatorStateHandle, List<String> resultCollector) throws Exception {
+		try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
+			for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : operatorStateHandle.getStateNameToPartitionOffsets().entrySet()) {
+				for (long offset : entry.getValue().getOffsets()) {
+					in.seek(offset);
+					Integer state = InstantiationUtil.
+						deserializeObject(in, Thread.currentThread().getContextClassLoader());
+					resultCollector.add(opIdx + " : " + entry.getKey() + " : " + state);
+				}
+			}
+		}
+	}
+
+	static ExecutionJobVertex mockExecutionJobVertex(
+		JobVertexID jobVertexID,
+		int parallelism,
+		int maxParallelism) {
+
+		return mockExecutionJobVertex(
+			jobVertexID,
+			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			parallelism,
+			maxParallelism
+		);
+	}
+
+	static ExecutionJobVertex mockExecutionJobVertex(
+		JobVertexID jobVertexID,
+		List<OperatorID> jobVertexIDs,
+		int parallelism,
+		int maxParallelism) {
+		final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class);
+
+		ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism];
+
+		for (int i = 0; i < parallelism; i++) {
+			executionVertices[i] = mockExecutionVertex(
+				new ExecutionAttemptID(),
+				jobVertexID,
+				jobVertexIDs,
+				parallelism,
+				maxParallelism,
+				ExecutionState.RUNNING);
+
+			when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i);
+		}
+
+		when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID);
+		when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices);
+		when(executionJobVertex.getParallelism()).thenReturn(parallelism);
+		when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism);
+		when(executionJobVertex.isMaxParallelismConfigured()).thenReturn(true);
+		when(executionJobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
+		when(executionJobVertex.getUserDefinedOperatorIDs()).thenReturn(Arrays.asList(new OperatorID[jobVertexIDs.size()]));
+
+		return executionJobVertex;
+	}
+
+	static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) {
+		JobVertexID jobVertexID = new JobVertexID();
+		return mockExecutionVertex(
+			attemptID,
+			jobVertexID,
+			Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)),
+			1,
+			1,
+			ExecutionState.RUNNING);
+	}
+
+	static ExecutionVertex mockExecutionVertex(
+		ExecutionAttemptID attemptID,
+		JobVertexID jobVertexID,
+		List<OperatorID> jobVertexIDs,
+		int parallelism,
+		int maxParallelism,
+		ExecutionState state,
+		ExecutionState ... successiveStates) {
+
+		ExecutionVertex vertex = mock(ExecutionVertex.class);
+
+		final Execution exec = spy(new Execution(
+			mock(Executor.class),
+			vertex,
+			1,
+			1L,
+			1L,
+			Time.milliseconds(500L)
+		));
+		when(exec.getAttemptId()).thenReturn(attemptID);
+		when(exec.getState()).thenReturn(state, successiveStates);
+
+		when(vertex.getJobvertexId()).thenReturn(jobVertexID);
+		when(vertex.getCurrentExecutionAttempt()).thenReturn(exec);
+		when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(vertex.getMaxParallelism()).thenReturn(maxParallelism);
+
+		ExecutionJobVertex jobVertex = mock(ExecutionJobVertex.class);
+		when(jobVertex.getOperatorIDs()).thenReturn(jobVertexIDs);
+
+		when(vertex.getJobVertex()).thenReturn(jobVertex);
+
+		return vertex;
+	}
+
+	static TaskStateSnapshot mockSubtaskState(
+		JobVertexID jobVertexID,
+		int index,
+		KeyGroupRange keyGroupRange) throws IOException {
+
+		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
+		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
+
+		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
+			partitionableState, null, partitionedKeyGroupState, null)
+		);
+
+		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
+
+		return subtaskStates;
+	}
+
+	public static KeyGroupsStateHandle generateKeyGroupState(
+		JobVertexID jobVertexID,
+		KeyGroupRange keyGroupPartition, boolean rawState) throws IOException {
+
+		List<Integer> testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups());
+
+		// generate state for one keygroup
+		for (int keyGroupIndex : keyGroupPartition) {
+			int vertexHash = jobVertexID.hashCode();
+			int seed = rawState ? (vertexHash * (31 + keyGroupIndex)) : (vertexHash + keyGroupIndex);
+			Random random = new Random(seed);
+			int simulatedStateValue = random.nextInt();
+			testStatesLists.add(simulatedStateValue);
+		}
+
+		return generateKeyGroupState(keyGroupPartition, testStatesLists);
+	}
+
+	public static KeyGroupsStateHandle generateKeyGroupState(
+		KeyGroupRange keyGroupRange,
+		List<? extends Serializable> states) throws IOException {
+
+		Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size());
+
+		Tuple2<byte[], List<long[]>> serializedDataWithOffsets =
+			serializeTogetherAndTrackOffsets(Collections.<List<? extends Serializable>>singletonList(states));
+
+		KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, serializedDataWithOffsets.f1.get(0));
+
+		ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle(
+			String.valueOf(UUID.randomUUID()),
+			serializedDataWithOffsets.f0);
+
+		return new KeyGroupsStateHandle(keyGroupRangeOffsets, allSerializedStatesHandle);
+	}
+
+	static Execution mockExecution() {
+		Execution mock = mock(Execution.class);
+		when(mock.getAttemptId()).thenReturn(new ExecutionAttemptID());
+		when(mock.getState()).thenReturn(ExecutionState.RUNNING);
+		return mock;
+	}
+
+	static ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID vertexId, int subtask, int parallelism) {
+		ExecutionVertex mock = mock(ExecutionVertex.class);
+		when(mock.getJobvertexId()).thenReturn(vertexId);
+		when(mock.getParallelSubtaskIndex()).thenReturn(subtask);
+		when(mock.getCurrentExecutionAttempt()).thenReturn(execution);
+		when(mock.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism);
+		when(mock.getMaxParallelism()).thenReturn(parallelism);
+		return mock;
+	}
+
+	static ExecutionJobVertex mockExecutionJobVertex(JobVertexID id, ExecutionVertex[] vertices) {
+		ExecutionJobVertex vertex = mock(ExecutionJobVertex.class);
+		when(vertex.getParallelism()).thenReturn(vertices.length);
+		when(vertex.getMaxParallelism()).thenReturn(vertices.length);
+		when(vertex.getJobVertexId()).thenReturn(id);
+		when(vertex.getTaskVertices()).thenReturn(vertices);
+		when(vertex.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(id)));
+		when(vertex.getUserDefinedOperatorIDs()).thenReturn(Collections.<OperatorID>singletonList(null));
+
+		for (ExecutionVertex v : vertices) {
+			when(v.getJobVertex()).thenReturn(vertex);
+		}
+		return vertex;
+	}
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index af2e5a3..42849b5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -79,9 +79,9 @@ public class CheckpointStateRestoreTest {
 	public void testSetState() {
 		try {
 
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(0,0);
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(0, 0);
 			List<SerializableObject> testStates = Collections.singletonList(new SerializableObject());
-			final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, testStates);
+			final KeyedStateHandle serializedKeyGroupStates = CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, testStates);
 
 			final JobID jid = new JobID();
 			final JobVertexID statefulId = new JobVertexID();
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index f18b4c8..0fb46e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.messages;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.testutils.CommonTestUtils;
-import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
+import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
@@ -45,6 +45,9 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
 
+/**
+ * Tests for checkpoint messages.
+ */
 public class CheckpointMessagesTest {
 
 	@Test
@@ -69,15 +72,15 @@ public class CheckpointMessagesTest {
 			AcknowledgeCheckpoint noState = new AcknowledgeCheckpoint(
 					new JobID(), new ExecutionAttemptID(), 569345L);
 
-			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
+			KeyGroupRange keyGroupRange = KeyGroupRange.of(42, 42);
 
 			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
 			checkpointStateHandles.putSubtaskStateByOperatorID(
 				new OperatorID(),
 				new OperatorSubtaskState(
-					CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+					CheckpointCoordinatorTestingUtils.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
 					null,
-					CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+					CheckpointCoordinatorTestingUtils.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
 					null
 				)
 			);
@@ -105,7 +108,7 @@ public class CheckpointMessagesTest {
 		assertNotNull(copy.toString());
 	}
 
-	public static class MyHandle implements StreamStateHandle {
+	private static class MyHandle implements StreamStateHandle {
 
 		private static final long serialVersionUID = 8128146204128728332L;
 


[flink] 08/08: [FLINK-13904][checkpointing] Encapsule and optimize the time relevant operation of CheckpointCoordinator

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit beb3fb06bdca64c4732318667ab59ce298da3b97
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Fri Oct 11 16:40:27 2019 +0800

    [FLINK-13904][checkpointing] Encapsule and optimize the time relevant operation of CheckpointCoordinator
---
 .../runtime/checkpoint/CheckpointCoordinator.java  | 61 ++++++++++++++++++----
 1 file changed, 50 insertions(+), 11 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index 7517a38..6d34d8c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -42,6 +42,8 @@ import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SharedStateRegistryFactory;
 import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.util.clock.Clock;
+import org.apache.flink.runtime.util.clock.SystemClock;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StringUtils;
@@ -130,9 +132,9 @@ public class CheckpointCoordinator {
 	/** The max time (in ms) that a checkpoint may take. */
 	private final long checkpointTimeout;
 
-	/** The min time(in ns) to delay after a checkpoint could be triggered. Allows to
+	/** The min time (in ms) to delay after a checkpoint could be triggered. Allows to
 	 * enforce minimum processing time between checkpoint attempts */
-	private final long minPauseBetweenCheckpointsNanos;
+	private final long minPauseBetweenCheckpoints;
 
 	/** The maximum number of checkpoints that may be in progress at the same time. */
 	private final int maxConcurrentCheckpointAttempts;
@@ -153,8 +155,9 @@ public class CheckpointCoordinator {
 	/** A handle to the current periodic trigger, to cancel it when necessary. */
 	private ScheduledFuture<?> currentPeriodicTrigger;
 
-	/** The timestamp (via {@link System#nanoTime()}) when the last checkpoint completed. */
-	private long lastCheckpointCompletionNanos;
+	/** The timestamp (via {@link Clock#relativeTimeMillis()}) when the last checkpoint
+	 * completed. */
+	private long lastCheckpointCompletionRelativeTime;
 
 	/** Flag whether a triggered checkpoint should immediately schedule the next checkpoint.
 	 * Non-volatile, because only accessed in synchronized scope */
@@ -181,9 +184,42 @@ public class CheckpointCoordinator {
 
 	private final CheckpointFailureManager failureManager;
 
+	private final Clock clock;
+
 	// --------------------------------------------------------------------------------------------
 
 	public CheckpointCoordinator(
+		JobID job,
+		CheckpointCoordinatorConfiguration chkConfig,
+		ExecutionVertex[] tasksToTrigger,
+		ExecutionVertex[] tasksToWaitFor,
+		ExecutionVertex[] tasksToCommitTo,
+		CheckpointIDCounter checkpointIDCounter,
+		CompletedCheckpointStore completedCheckpointStore,
+		StateBackend checkpointStateBackend,
+		Executor executor,
+		ScheduledExecutor timer,
+		SharedStateRegistryFactory sharedStateRegistryFactory,
+		CheckpointFailureManager failureManager) {
+
+		this(
+			job,
+			chkConfig,
+			tasksToTrigger,
+			tasksToWaitFor,
+			tasksToCommitTo,
+			checkpointIDCounter,
+			completedCheckpointStore,
+			checkpointStateBackend,
+			executor,
+			timer,
+			sharedStateRegistryFactory,
+			failureManager,
+			SystemClock.getInstance());
+	}
+
+	@VisibleForTesting
+	public CheckpointCoordinator(
 			JobID job,
 			CheckpointCoordinatorConfiguration chkConfig,
 			ExecutionVertex[] tasksToTrigger,
@@ -195,7 +231,8 @@ public class CheckpointCoordinator {
 			Executor executor,
 			ScheduledExecutor timer,
 			SharedStateRegistryFactory sharedStateRegistryFactory,
-			CheckpointFailureManager failureManager) {
+			CheckpointFailureManager failureManager,
+			Clock clock) {
 
 		// sanity checks
 		checkNotNull(checkpointStateBackend);
@@ -216,7 +253,7 @@ public class CheckpointCoordinator {
 		this.job = checkNotNull(job);
 		this.baseInterval = baseInterval;
 		this.checkpointTimeout = chkConfig.getCheckpointTimeout();
-		this.minPauseBetweenCheckpointsNanos = minPauseBetweenCheckpoints * 1_000_000;
+		this.minPauseBetweenCheckpoints = minPauseBetweenCheckpoints;
 		this.maxConcurrentCheckpointAttempts = chkConfig.getMaxConcurrentCheckpoints();
 		this.tasksToTrigger = checkNotNull(tasksToTrigger);
 		this.tasksToWaitFor = checkNotNull(tasksToWaitFor);
@@ -229,6 +266,7 @@ public class CheckpointCoordinator {
 		this.sharedStateRegistry = sharedStateRegistryFactory.create(executor);
 		this.isPreferCheckpointForRecovery = chkConfig.isPreferCheckpointForRecovery();
 		this.failureManager = checkNotNull(failureManager);
+		this.clock = checkNotNull(clock);
 
 		this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 		this.masterHooks = new HashMap<>();
@@ -890,7 +928,7 @@ public class CheckpointCoordinator {
 
 		// record the time when this was completed, to calculate
 		// the 'min delay between checkpoints'
-		lastCheckpointCompletionNanos = System.nanoTime();
+		lastCheckpointCompletionRelativeTime = clock.relativeTimeMillis();
 
 		LOG.info("Completed checkpoint {} for job {} ({} bytes in {} ms).", checkpointId, job,
 			completedCheckpoint.getStateSize(), completedCheckpoint.getDuration());
@@ -1253,8 +1291,10 @@ public class CheckpointCoordinator {
 	 * @throws CheckpointException If the minimum interval between checkpoints has not passed.
 	 */
 	private void checkMinPauseBetweenCheckpoints() throws CheckpointException {
-		final long earliestNext = lastCheckpointCompletionNanos + minPauseBetweenCheckpointsNanos;
-		final long durationTillNextMillis = (earliestNext - System.nanoTime()) / 1_000_000;
+		final long nextCheckpointTriggerRelativeTime =
+			lastCheckpointCompletionRelativeTime + minPauseBetweenCheckpoints;
+		final long durationTillNextMillis =
+			nextCheckpointTriggerRelativeTime - clock.relativeTimeMillis();
 
 		if (durationTillNextMillis > 0) {
 			if (currentPeriodicTrigger != null) {
@@ -1269,8 +1309,7 @@ public class CheckpointCoordinator {
 	}
 
 	private long getRandomInitDelay() {
-		return ThreadLocalRandom.current().nextLong(
-			minPauseBetweenCheckpointsNanos / 1_000_000L, baseInterval + 1L);
+		return ThreadLocalRandom.current().nextLong(minPauseBetweenCheckpoints, baseInterval + 1L);
 	}
 
 	private ScheduledFuture<?> scheduleTriggerWithDelay(long initDelay) {


[flink] 04/08: [FLINK-13904][checkpointing] Make trigger thread of CheckpointCoordinator single-threaded

Posted by pn...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 5ab6261df2efb3cb34403cd76e77ca3672c57066
Author: ifndef-SleePy <mm...@gmail.com>
AuthorDate: Wed Sep 18 21:04:09 2019 +0800

    [FLINK-13904][checkpointing] Make trigger thread of CheckpointCoordinator single-threaded
---
 .../runtime/checkpoint/CheckpointCoordinator.java  |  25 +----
 .../runtime/executiongraph/ExecutionGraph.java     |  19 ++++
 .../CheckpointCoordinatorFailureTest.java          |   7 ++
 .../CheckpointCoordinatorMasterHooksTest.java      |   9 +-
 .../CheckpointCoordinatorRestoringTest.java        |  10 ++
 .../checkpoint/CheckpointCoordinatorTest.java      |  73 ++++++++++---
 .../CheckpointCoordinatorTestingUtils.java         | 121 +++++++++++++++++++++
 .../CheckpointCoordinatorTriggeringTest.java       |   9 ++
 .../checkpoint/CheckpointStateRestoreTest.java     |   9 ++
 .../FailoverStrategyCheckpointCoordinatorTest.java |  57 +---------
 .../runtime/util/TestingScheduledExecutor.java     |  62 +++++++++++
 11 files changed, 309 insertions(+), 92 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
index e7c7e17..df9278e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinator.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.runtime.checkpoint.hooks.MasterHooks;
 import org.apache.flink.runtime.concurrent.FutureUtils;
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -41,7 +42,6 @@ import org.apache.flink.runtime.state.CompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.SharedStateRegistryFactory;
 import org.apache.flink.runtime.state.StateBackend;
-import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.StringUtils;
@@ -62,7 +62,6 @@ import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledFuture;
-import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
@@ -145,8 +144,9 @@ public class CheckpointCoordinator {
 	/** The maximum number of checkpoints that may be in progress at the same time. */
 	private final int maxConcurrentCheckpointAttempts;
 
-	/** The timer that handles the checkpoint timeouts and triggers periodic checkpoints. */
-	private final ScheduledThreadPoolExecutor timer;
+	/** The timer that handles the checkpoint timeouts and triggers periodic checkpoints.
+	 * It must be single-threaded. Eventually it will be replaced by the main thread executor. */
+	private final ScheduledExecutor timer;
 
 	/** The master checkpoint hooks executed by this checkpoint coordinator. */
 	private final HashMap<String, MasterTriggerRestoreHook<?>> masterHooks;
@@ -200,6 +200,7 @@ public class CheckpointCoordinator {
 			CompletedCheckpointStore completedCheckpointStore,
 			StateBackend checkpointStateBackend,
 			Executor executor,
+			ScheduledExecutor timer,
 			SharedStateRegistryFactory sharedStateRegistryFactory,
 			CheckpointFailureManager failureManager) {
 
@@ -239,13 +240,7 @@ public class CheckpointCoordinator {
 		this.recentPendingCheckpoints = new ArrayDeque<>(NUM_GHOST_CHECKPOINT_IDS);
 		this.masterHooks = new HashMap<>();
 
-		this.timer = new ScheduledThreadPoolExecutor(1,
-				new DispatcherThreadFactory(Thread.currentThread().getThreadGroup(), "Checkpoint Timer"));
-
-		// make sure the timer internally cleans up and does not hold onto stale scheduled tasks
-		this.timer.setRemoveOnCancelPolicy(true);
-		this.timer.setContinueExistingPeriodicTasksAfterShutdownPolicy(false);
-		this.timer.setExecuteExistingDelayedTasksAfterShutdownPolicy(false);
+		this.timer = timer;
 
 		this.checkpointProperties = CheckpointProperties.forCheckpoint(chkConfig.getCheckpointRetentionPolicy());
 
@@ -336,9 +331,6 @@ public class CheckpointCoordinator {
 				MasterHooks.close(masterHooks.values(), LOG);
 				masterHooks.clear();
 
-				// shut down the thread that handles the timeouts and pending triggers
-				timer.shutdownNow();
-
 				// clear and discard all pending checkpoints
 				for (PendingCheckpoint pending : pendingCheckpoints.values()) {
 					failPendingCheckpoint(pending, CheckpointFailureReason.CHECKPOINT_COORDINATOR_SHUTDOWN);
@@ -997,11 +989,6 @@ public class CheckpointCoordinator {
 		}
 	}
 
-	@VisibleForTesting
-	int getNumScheduledTasks() {
-		return timer.getQueue().size();
-	}
-
 	// --------------------------------------------------------------------------------------------
 	//  Checkpoint State Restoring
 	// --------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index dab62df..a3b6374 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -45,6 +45,7 @@ import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.FutureUtils.ConjunctFuture;
+import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.execution.SuppressRestartsException;
 import org.apache.flink.runtime.executiongraph.failover.FailoverStrategy;
@@ -79,6 +80,7 @@ import org.apache.flink.runtime.shuffle.NettyShuffleMaster;
 import org.apache.flink.runtime.shuffle.ShuffleMaster;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.taskmanager.DispatcherThreadFactory;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.types.Either;
 import org.apache.flink.util.ExceptionUtils;
@@ -113,6 +115,7 @@ import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLongFieldUpdater;
@@ -319,8 +322,13 @@ public class ExecutionGraph implements AccessExecutionGraph {
 	// ------ Fields that are relevant to the execution and need to be cleared before archiving  -------
 
 	/** The coordinator for checkpoints, if snapshot checkpoints are enabled. */
+	@Nullable
 	private CheckpointCoordinator checkpointCoordinator;
 
+	/** TODO: replace it with the main thread executor. */
+	@Nullable
+	private ScheduledExecutorService checkpointCoordinatorTimer;
+
 	/** Checkpoint stats tracker separate from the coordinator in order to be
 	 * available after archiving. */
 	private CheckpointStatsTracker checkpointStatsTracker;
@@ -604,6 +612,12 @@ public class ExecutionGraph implements AccessExecutionGraph {
 			}
 		);
 
+		checkState(checkpointCoordinatorTimer == null);
+
+		checkpointCoordinatorTimer = Executors.newSingleThreadScheduledExecutor(
+			new DispatcherThreadFactory(
+				Thread.currentThread().getThreadGroup(), "Checkpoint Timer"));
+
 		// create the coordinator that triggers and commits checkpoints and holds the state
 		checkpointCoordinator = new CheckpointCoordinator(
 			jobInformation.getJobId(),
@@ -615,6 +629,7 @@ public class ExecutionGraph implements AccessExecutionGraph {
 			checkpointStore,
 			checkpointStateBackend,
 			ioExecutor,
+			new ScheduledExecutorServiceAdapter(checkpointCoordinatorTimer),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1559,6 +1574,10 @@ public class ExecutionGraph implements AccessExecutionGraph {
 			if (coord != null) {
 				coord.shutdown(status);
 			}
+			if (checkpointCoordinatorTimer != null) {
+				checkpointCoordinatorTimer.shutdownNow();
+				checkpointCoordinatorTimer = null;
+			}
 		}
 		catch (Exception e) {
 			LOG.error("Error while cleaning up after execution", e);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index b6b7930..39a8c2e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -31,8 +31,10 @@ import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.OperatorStreamStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.TestLogger;
 
+import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -56,6 +58,10 @@ import static org.mockito.Mockito.when;
 @PrepareForTest(PendingCheckpoint.class)
 public class CheckpointCoordinatorFailureTest extends TestLogger {
 
+	@Rule
+	public final TestingScheduledExecutor testingScheduledExecutor =
+		new TestingScheduledExecutor();
+
 	/**
 	 * Tests that a failure while storing a completed checkpoint in the completed checkpoint store
 	 * will properly fail the originating pending checkpoint and clean up the completed checkpoint.
@@ -93,6 +99,7 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 			new FailingCompletedCheckpointStore(),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
index 8453cba..d85d014 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorMasterHooksTest.java
@@ -32,7 +32,9 @@ import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
+import org.apache.flink.runtime.util.TestingScheduledExecutor;
 
+import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -69,6 +71,10 @@ import static org.mockito.Mockito.when;
  */
 public class CheckpointCoordinatorMasterHooksTest {
 
+	@Rule
+	public final TestingScheduledExecutor testingScheduledExecutor =
+		new TestingScheduledExecutor();
+
 	// ------------------------------------------------------------------------
 	//  hook registration
 	// ------------------------------------------------------------------------
@@ -421,7 +427,7 @@ public class CheckpointCoordinatorMasterHooksTest {
 	//  utilities
 	// ------------------------------------------------------------------------
 
-	private static CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, ExecutionVertex... ackVertices) {
+	private CheckpointCoordinator instantiateCheckpointCoordinator(JobID jid, ExecutionVertex... ackVertices) {
 		CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
 			10000000L,
 			600000L,
@@ -441,6 +447,7 @@ public class CheckpointCoordinatorMasterHooksTest {
 				new StandaloneCompletedCheckpointStore(10),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				new CheckpointFailureManager(
 					0,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
index 1259144..1725ef2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorRestoringTest.java
@@ -39,6 +39,7 @@ import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
+import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.SerializableObject;
 import org.apache.flink.util.TestLogger;
 
@@ -89,6 +90,10 @@ import static org.mockito.Mockito.when;
 public class CheckpointCoordinatorRestoringTest extends TestLogger {
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
+	@Rule
+	public final TestingScheduledExecutor testingScheduledExecutor =
+		new TestingScheduledExecutor();
+
 	private CheckpointFailureManager failureManager;
 
 	@Rule
@@ -158,6 +163,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			store,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -290,6 +296,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 				store,
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -439,6 +446,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -616,6 +624,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -870,6 +879,7 @@ public class CheckpointCoordinatorRestoringTest extends TestLogger {
 			standaloneCompletedCheckpointStore,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index 2d86a06..ddd2f8e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -21,7 +21,9 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTestingUtils.TestingScheduledServiceWithRecordingScheduledTasks;
 import org.apache.flink.runtime.concurrent.Executors;
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -48,6 +50,7 @@ import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
+import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.TestLogger;
 
@@ -103,6 +106,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
+	@Rule
+	public final TestingScheduledExecutor testingScheduledExecutor =
+		new TestingScheduledExecutor();
+
 	private CheckpointFailureManager failureManager;
 
 	@Rule
@@ -151,6 +158,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -218,6 +226,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -276,6 +285,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -325,8 +335,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				}
 			});
 
+		final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
+			new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
+
 		// set up the coordinator
-		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, checkpointFailureManager);
+		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, checkpointFailureManager, scheduledExecutorService);
 
 		try {
 			// trigger the checkpoint. this should succeed
@@ -376,8 +389,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 			ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
+			final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
+				new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 			// set up the coordinator and validate the initial state
-			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager);
+			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -390,7 +405,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// we have one task scheduled that will cancel after timeout
-			assertEquals(1, coord.getNumScheduledTasks());
+			assertEquals(1, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId);
@@ -427,7 +442,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertTrue(checkpoint.isDiscarded());
 
 			// the canceler is also removed
-			assertEquals(0, coord.getNumScheduledTasks());
+			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			// validate that we have no new pending checkpoint
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
@@ -464,12 +479,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 			ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
+			final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
+				new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 			// set up the coordinator and validate the initial state
-			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager);
+			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(0, coord.getNumScheduledTasks());
+			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			// trigger the first checkpoint. this should succeed
 			assertTrue(coord.triggerCheckpoint(timestamp, false));
@@ -480,7 +497,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// validate that we have a pending checkpoint
 			assertEquals(2, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(2, coord.getNumScheduledTasks());
+			assertEquals(2, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			Iterator<Map.Entry<Long, PendingCheckpoint>> it = coord.getPendingCheckpoints().entrySet().iterator();
 			long checkpoint1Id = it.next().getKey();
@@ -527,7 +544,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// validate that we have only one pending checkpoint left
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(1, coord.getNumScheduledTasks());
+			assertEquals(1, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			// validate that it is the same second checkpoint from earlier
 			long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
@@ -570,12 +587,14 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 			ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
+			final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
+				new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 			// set up the coordinator and validate the initial state
-			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager);
+			CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(0, coord.getNumScheduledTasks());
+			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			// trigger the first checkpoint. this should succeed
 			assertTrue(coord.triggerCheckpoint(timestamp, false));
@@ -583,7 +602,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// validate that we have a pending checkpoint
 			assertEquals(1, coord.getNumberOfPendingCheckpoints());
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(1, coord.getNumScheduledTasks());
+			assertEquals(1, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			long checkpointId = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			PendingCheckpoint checkpoint = coord.getPendingCheckpoints().get(checkpointId);
@@ -640,7 +659,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 
 			// the canceler should be removed now
-			assertEquals(0, coord.getNumScheduledTasks());
+			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			// validate that the subtasks states have registered their shared states.
 			{
@@ -672,7 +691,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
-			assertEquals(0, coord.getNumScheduledTasks());
+			assertEquals(0, scheduledExecutorService.getNumScheduledOnceTasks());
 
 			CompletedCheckpoint successNew = coord.getSuccessfulCheckpoints().get(0);
 			assertEquals(jid, successNew.getJobId());
@@ -744,6 +763,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -881,6 +901,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(10),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -1051,6 +1072,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -1135,6 +1157,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -1205,6 +1228,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1310,8 +1334,10 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 		ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
+		final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
+			new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 		// set up the coordinator and validate the initial state
-		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager);
+		CheckpointCoordinator coord = getCheckpointCoordinator(jid, vertex1, vertex2, failureManager, scheduledExecutorService);
 
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -1468,6 +1494,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(10),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1568,6 +1595,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -1648,6 +1676,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -1731,6 +1760,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -1790,6 +1820,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(2),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1850,6 +1881,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(2),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -1895,6 +1927,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -2131,6 +2164,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -2176,6 +2210,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			store,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -2245,6 +2280,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			store,
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 				deleteExecutor -> {
 					SharedStateRegistry instance = new SharedStateRegistry(deleteExecutor);
 					createdSharedStateRegistries.add(instance);
@@ -2385,6 +2421,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		final ExecutionVertex vertex1 = mockExecutionVertex(attemptID1);
 		final ExecutionVertex vertex2 = mockExecutionVertex(attemptID2);
 
+		final TestingScheduledServiceWithRecordingScheduledTasks scheduledExecutorService =
+			new TestingScheduledServiceWithRecordingScheduledTasks(testingScheduledExecutor.getScheduledExecutor());
 		// set up the coordinator and validate the initial state
 		final CheckpointCoordinator coordinator = getCheckpointCoordinator(jobId, vertex1, vertex2,
 				new CheckpointFailureManager(
@@ -2400,7 +2438,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 						public void failJobDueToTaskFailure(Throwable cause, ExecutionAttemptID failingTask) {
 							throw new AssertionError("This method should not be called for the test.");
 						}
-					}));
+					}),
+			scheduledExecutorService);
 
 		final CompletableFuture<CompletedCheckpoint> savepointFuture = coordinator
 				.triggerSynchronousSavepoint(10L, false, "test-dir");
@@ -2430,7 +2469,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			final JobID jobId,
 			final ExecutionVertex vertex1,
 			final ExecutionVertex vertex2,
-			final CheckpointFailureManager failureManager) {
+			final CheckpointFailureManager failureManager,
+			final ScheduledExecutor timer) {
 
 		final CheckpointCoordinatorConfiguration chkConfig = new CheckpointCoordinatorConfiguration(
 				600000,
@@ -2452,6 +2492,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				timer,
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
index 9578ac7..76f042b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTestingUtils.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.checkpoint;
 import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -49,12 +50,21 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.UUID;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Delayed;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executor;
+import java.util.concurrent.ScheduledFuture;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
 
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
@@ -469,4 +479,135 @@ public class CheckpointCoordinatorTestingUtils {
 		}
 		return vertex;
 	}
+
+	/**
+	 * {@link ScheduledExecutor} decorator that records the one-shot tasks which have been scheduled but have
+	 * neither started running nor been cancelled. Periodic tasks and {@code execute} are forwarded untracked.
+	 */
+	static class TestingScheduledServiceWithRecordingScheduledTasks implements ScheduledExecutor {
+
+		private final ScheduledExecutor scheduledExecutor;
+
+		/** Ids of the recorded one-shot tasks; every access synchronizes on the set itself. */
+		private final Set<UUID> tasksScheduledOnce = new HashSet<>();
+
+		public TestingScheduledServiceWithRecordingScheduledTasks(ScheduledExecutor scheduledExecutor) {
+			this.scheduledExecutor = checkNotNull(scheduledExecutor);
+		}
+
+		/** Returns the number of one-shot tasks that are currently recorded. */
+		public int getNumScheduledOnceTasks() {
+			synchronized (tasksScheduledOnce) {
+				return tasksScheduledOnce.size();
+			}
+		}
+
+		@Override
+		public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
+			final UUID id = recordTask();
+			try {
+				return new TestingScheduledFuture<>(id, scheduledExecutor.schedule(() -> {
+					forgetTask(id); // a task stops counting as scheduled as soon as it starts running
+					command.run();
+				}, delay, unit));
+			} catch (RuntimeException e) {
+				forgetTask(id); // the delegate did not accept the task, so it must not stay recorded
+				throw e;
+			}
+		}
+
+		@Override
+		public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
+			final UUID id = recordTask();
+			try {
+				return new TestingScheduledFuture<>(id, scheduledExecutor.schedule(() -> {
+					forgetTask(id);
+					return callable.call();
+				}, delay, unit));
+			} catch (RuntimeException e) {
+				forgetTask(id);
+				throw e;
+			}
+		}
+
+		@Override
+		public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) {
+			return scheduledExecutor.scheduleAtFixedRate(command, initialDelay, period, unit);
+		}
+
+		@Override
+		public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) {
+			return scheduledExecutor.scheduleWithFixedDelay(command, initialDelay, delay, unit);
+		}
+
+		@Override
+		public void execute(Runnable command) {
+			scheduledExecutor.execute(command);
+		}
+
+		/** Records a new one-shot task and returns the random id it is tracked under. */
+		private UUID recordTask() {
+			final UUID id = UUID.randomUUID();
+			synchronized (tasksScheduledOnce) {
+				tasksScheduledOnce.add(id);
+			}
+			return id;
+		}
+
+		/** Removes the given id from the recorded tasks; does nothing if it is no longer present. */
+		private void forgetTask(UUID id) {
+			synchronized (tasksScheduledOnce) {
+				tasksScheduledOnce.remove(id);
+			}
+		}
+
+		/** Future of a recorded one-shot task; {@link #cancel(boolean)} also drops the task from the record. */
+		private class TestingScheduledFuture<V> implements ScheduledFuture<V> {
+
+			private final ScheduledFuture<V> scheduledFuture;
+
+			private final UUID id;
+
+			public TestingScheduledFuture(UUID id, ScheduledFuture<V> scheduledFuture) {
+				this.id = checkNotNull(id);
+				this.scheduledFuture = checkNotNull(scheduledFuture);
+			}
+
+			@Override
+			public long getDelay(TimeUnit unit) {
+				return scheduledFuture.getDelay(unit);
+			}
+
+			@Override
+			public int compareTo(Delayed o) {
+				return scheduledFuture.compareTo(o);
+			}
+
+			@Override
+			public boolean cancel(boolean mayInterruptIfRunning) {
+				forgetTask(id);
+				return scheduledFuture.cancel(mayInterruptIfRunning);
+			}
+
+			@Override
+			public boolean isCancelled() {
+				return scheduledFuture.isCancelled();
+			}
+
+			@Override
+			public boolean isDone() {
+				return scheduledFuture.isDone();
+			}
+
+			@Override
+			public V get() throws InterruptedException, ExecutionException {
+				return scheduledFuture.get();
+			}
+
+			@Override
+			public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
+				return scheduledFuture.get(timeout, unit);
+			}
+		}
+	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
index b224d9e..a157b6f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTriggeringTest.java
@@ -28,9 +28,11 @@ import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguratio
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
+import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
@@ -53,6 +55,10 @@ import static org.mockito.Mockito.doAnswer;
 public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
+	@Rule
+	public final TestingScheduledExecutor testingScheduledExecutor =
+		new TestingScheduledExecutor();
+
 	private CheckpointFailureManager failureManager;
 
 	@Before
@@ -122,6 +128,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 				new StandaloneCompletedCheckpointStore(2),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -216,6 +223,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(2),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
@@ -276,6 +284,7 @@ public class CheckpointCoordinatorTriggeringTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 42849b5..080e1c7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -34,11 +34,13 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.runtime.state.testutils.TestCompletedCheckpointStorageLocation;
+import org.apache.flink.runtime.util.TestingScheduledExecutor;
 import org.apache.flink.util.SerializableObject;
 
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.Mockito;
 import org.mockito.hamcrest.MockitoHamcrest;
@@ -63,6 +65,10 @@ public class CheckpointStateRestoreTest {
 
 	private static final String TASK_MANAGER_LOCATION_INFO = "Unknown location";
 
+	@Rule
+	public final TestingScheduledExecutor testingScheduledExecutor =
+		new TestingScheduledExecutor();
+
 	private CheckpointFailureManager failureManager;
 
 	@Before
@@ -127,6 +133,7 @@ public class CheckpointStateRestoreTest {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -211,6 +218,7 @@ public class CheckpointStateRestoreTest {
 				new StandaloneCompletedCheckpointStore(1),
 				new MemoryStateBackend(),
 				Executors.directExecutor(),
+				testingScheduledExecutor.getScheduledExecutor(),
 				SharedStateRegistry.DEFAULT_FACTORY,
 				failureManager);
 
@@ -276,6 +284,7 @@ public class CheckpointStateRestoreTest {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			testingScheduledExecutor.getScheduledExecutor(),
 			SharedStateRegistry.DEFAULT_FACTORY,
 			failureManager);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
index 17a5bcc..df3e97e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/FailoverStrategyCheckpointCoordinatorTest.java
@@ -19,10 +19,8 @@
 package org.apache.flink.runtime.checkpoint;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.mock.Whitebox;
 import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutor;
-import org.apache.flink.runtime.concurrent.ScheduledExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
@@ -33,17 +31,12 @@ import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguratio
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.util.TestLogger;
-import org.apache.flink.util.concurrent.NeverCompleteFuture;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import java.util.concurrent.Callable;
-import java.util.concurrent.ScheduledFuture;
-import java.util.concurrent.ScheduledThreadPoolExecutor;
 import java.util.concurrent.ThreadLocalRandom;
-import java.util.concurrent.TimeUnit;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -90,19 +83,13 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
 			new StandaloneCompletedCheckpointStore(1),
 			new MemoryStateBackend(),
 			Executors.directExecutor(),
+			manualThreadExecutor,
 			SharedStateRegistry.DEFAULT_FACTORY,
 			mock(CheckpointFailureManager.class));
 
 		// switch current execution's state to running to allow checkpoint could be triggered.
 		mockExecutionRunning(executionVertex);
 
-		// use manual checkpoint timer to trigger period checkpoints as we expect.
-		ManualCheckpointTimer manualCheckpointTimer = new ManualCheckpointTimer(manualThreadExecutor);
-		// set the init delay as 0 to ensure first checkpoint could be triggered once we trigger the manual executor
-		// this is used to avoid the randomness of when to trigger the first checkpoint (introduced via FLINK-9352)
-		manualCheckpointTimer.setManualDelay(0L);
-		Whitebox.setInternalState(checkpointCoordinator, "timer", manualCheckpointTimer);
-
 		checkpointCoordinator.startCheckpointScheduler();
 		assertTrue(checkpointCoordinator.isCurrentPeriodicTriggerAvailable());
 		manualThreadExecutor.triggerAll();
@@ -140,46 +127,4 @@ public class FailoverStrategyCheckpointCoordinatorTest extends TestLogger {
 	private void mockExecutionRunning(ExecutionVertex executionVertex) {
 		when(executionVertex.getCurrentExecutionAttempt().getState()).thenReturn(ExecutionState.RUNNING);
 	}
-
-	public static class ManualCheckpointTimer extends ScheduledThreadPoolExecutor {
-		private final ScheduledExecutor scheduledExecutor;
-		private long manualDelay = 0;
-
-		ManualCheckpointTimer(final ScheduledExecutor scheduledExecutor) {
-			super(0);
-			this.scheduledExecutor = scheduledExecutor;
-		}
-
-		void setManualDelay(long manualDelay) {
-			this.manualDelay = manualDelay;
-		}
-
-		@Override
-		public ScheduledFuture<?> schedule(Runnable command, long delay, TimeUnit unit) {
-			// used as checkpoint canceller, as we don't want pending checkpoint cancelled, this should never be scheduled.
-			return new NeverCompleteFuture(delay);
-		}
-
-		@Override
-		public <V> ScheduledFuture<V> schedule(Callable<V> callable, long delay, TimeUnit unit) {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public ScheduledFuture<?> scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) {
-			// used to schedule periodic checkpoints.
-			// this would use configured 'manualDelay' to let the task schedule with the wanted delay.
-			return scheduledExecutor.scheduleWithFixedDelay(command, manualDelay, period, unit);
-		}
-
-		@Override
-		public ScheduledFuture<?> scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) {
-			throw new UnsupportedOperationException();
-		}
-
-		@Override
-		public void execute(Runnable command) {
-			scheduledExecutor.execute(command);
-		}
-	}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java
new file mode 100644
index 0000000..930af9f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/TestingScheduledExecutor.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.util;
+
+import org.apache.flink.runtime.concurrent.ScheduledExecutor;
+import org.apache.flink.runtime.concurrent.ScheduledExecutorServiceAdapter;
+import org.apache.flink.util.ExecutorUtils;
+
+import org.junit.rules.ExternalResource;
+
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * JUnit rule that provides a single-threaded scheduled executor for testing and shuts it down automatically.
+ */
+public class TestingScheduledExecutor extends ExternalResource {
+
+	private final long shutdownTimeoutMillis; // timeout passed to the graceful shutdown in after()
+	private ScheduledExecutor scheduledExecutor; // adapter around innerExecutorService, set in before()
+	private ScheduledExecutorService innerExecutorService; // created in before(), shut down in after()
+
+	public TestingScheduledExecutor() {
+		this(500L);
+	}
+
+	public TestingScheduledExecutor(long shutdownTimeoutMillis) {
+		this.shutdownTimeoutMillis = shutdownTimeoutMillis;
+	}
+
+	@Override
+	protected void before() {
+		this.innerExecutorService = Executors.newSingleThreadScheduledExecutor();
+		this.scheduledExecutor = new ScheduledExecutorServiceAdapter(innerExecutorService);
+	}
+
+	@Override
+	protected void after() {
+		ExecutorUtils.gracefulShutdown(shutdownTimeoutMillis, TimeUnit.MILLISECONDS, innerExecutorService);
+	}
+
+	public ScheduledExecutor getScheduledExecutor() {
+		return scheduledExecutor;
+	}
+}