You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2020/06/10 05:27:00 UTC

[flink] branch release-1.11 updated: [FLINK-18136][checkpointing] Don't start channel state writer for savepoint

This is an automated email from the ASF dual-hosted git repository.

zhijiang pushed a commit to branch release-1.11
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.11 by this push:
     new 7f410fa  [FLINK-18136][checkpointing] Don't start channel state writer for savepoint
7f410fa is described below

commit 7f410fa4f4758e48fcc6c33d8354c21ab02e2dc6
Author: Roman <kh...@gmail.com>
AuthorDate: Tue Jun 9 10:58:46 2020 +0200

    [FLINK-18136][checkpointing] Don't start channel state writer for savepoint
    
    ChannelStateWriter#start should only be called for unaligned checkpoints. When a source triggers
    a savepoint, the new method SubtaskCheckpointCoordinator#initCheckpoint now decides whether
    to start the internal writer. This new method is also used in other places, such as
    CheckpointBarrierUnaligner.
    
    This closes #12489.
---
 .../state/TestCheckpointStorageWorkerView.java     | 52 +++++++++++++
 .../runtime/io/CheckpointBarrierUnaligner.java     | 15 ++--
 .../streaming/runtime/io/InputProcessorUtil.java   | 14 ++--
 .../runtime/tasks/MultipleInputStreamTask.java     |  2 +-
 .../runtime/tasks/OneInputStreamTask.java          |  2 +-
 .../flink/streaming/runtime/tasks/StreamTask.java  |  7 +-
 .../tasks/SubtaskCheckpointCoordinator.java        | 10 ++-
 .../tasks/SubtaskCheckpointCoordinatorImpl.java    | 50 ++++++++++--
 .../runtime/tasks/TwoInputStreamTask.java          |  2 +-
 .../AlternatingCheckpointBarrierHandlerTest.java   | 12 +--
 ...CheckpointBarrierUnalignerCancellationTest.java |  4 +-
 .../runtime/io/CheckpointBarrierUnalignerTest.java | 14 ++--
 .../runtime/io/InputProcessorUtilTest.java         |  3 +-
 .../runtime/io/StreamTaskNetworkInputTest.java     |  5 +-
 .../tasks/SubtaskCheckpointCoordinatorTest.java    | 47 ++++++++++++
 .../tasks/TestSubtaskCheckpointCoordinator.java    | 89 ++++++++++++++++++++++
 16 files changed, 282 insertions(+), 46 deletions(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestCheckpointStorageWorkerView.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestCheckpointStorageWorkerView.java
new file mode 100644
index 0000000..c50a390
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TestCheckpointStorageWorkerView.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+
+import java.io.IOException;
+
+/**
+ * Non-persistent {@link CheckpointStorageWorkerView} for tests. Uses {@link MemCheckpointStreamFactory}.
+ */
+public class TestCheckpointStorageWorkerView implements CheckpointStorageWorkerView {
+
+	private final int maxStateSize;
+	private final MemCheckpointStreamFactory taskOwnedCheckpointStreamFactory;
+	private final CheckpointedStateScope taskOwnedStateScope;
+
+	public TestCheckpointStorageWorkerView(int maxStateSize) {
+		this(maxStateSize, CheckpointedStateScope.EXCLUSIVE);
+	}
+
+	private TestCheckpointStorageWorkerView(int maxStateSize, CheckpointedStateScope taskOwnedStateScope) {
+		this.maxStateSize = maxStateSize;
+		this.taskOwnedCheckpointStreamFactory = new MemCheckpointStreamFactory(maxStateSize);
+		this.taskOwnedStateScope = taskOwnedStateScope;
+	}
+
+	@Override
+	public CheckpointStreamFactory resolveCheckpointStorageLocation(long checkpointId, CheckpointStorageLocationReference reference) {
+		return new MemCheckpointStreamFactory(maxStateSize);
+	}
+
+	@Override
+	public CheckpointStreamFactory.CheckpointStateOutputStream createTaskOwnedStateStream() throws IOException {
+		return taskOwnedCheckpointStreamFactory.createCheckpointStateOutputStream(taskOwnedStateScope);
+	}
+}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
index 817e73d..f978e72 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnaligner.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferReceivedListener;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.streaming.runtime.tasks.SubtaskCheckpointCoordinator;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -91,7 +92,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 
 	CheckpointBarrierUnaligner(
 			int[] numberOfInputChannelsPerGate,
-			ChannelStateWriter channelStateWriter,
+			SubtaskCheckpointCoordinator checkpointCoordinator,
 			String taskName,
 			AbstractInvokable toNotifyOnCheckpoint) {
 		super(toNotifyOnCheckpoint);
@@ -114,7 +115,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 			.flatMap(Function.identity())
 			.toArray(InputChannelInfo[]::new);
 
-		threadSafeUnaligner = new ThreadSafeUnaligner(totalNumChannels,	checkNotNull(channelStateWriter), this);
+		threadSafeUnaligner = new ThreadSafeUnaligner(totalNumChannels,	checkNotNull(checkpointCoordinator), this);
 	}
 
 	/**
@@ -276,14 +277,14 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 
 		private int numOpenChannels;
 
-		private final ChannelStateWriter channelStateWriter;
+		private final SubtaskCheckpointCoordinator checkpointCoordinator;
 
 		private final CheckpointBarrierUnaligner handler;
 
-		ThreadSafeUnaligner(int totalNumChannels, ChannelStateWriter channelStateWriter, CheckpointBarrierUnaligner handler) {
+		ThreadSafeUnaligner(int totalNumChannels, SubtaskCheckpointCoordinator checkpointCoordinator, CheckpointBarrierUnaligner handler) {
 			this.numOpenChannels = totalNumChannels;
 			this.storeNewBuffers = new boolean[totalNumChannels];
-			this.channelStateWriter = channelStateWriter;
+			this.checkpointCoordinator = checkpointCoordinator;
 			this.handler = handler;
 		}
 
@@ -313,7 +314,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 		@Override
 		public synchronized void notifyBufferReceived(Buffer buffer, InputChannelInfo channelInfo) {
 			if (storeNewBuffers[handler.getFlattenedChannelIndex(channelInfo)]) {
-				channelStateWriter.addInputData(
+				checkpointCoordinator.getChannelStateWriter().addInputData(
 					currentReceivedCheckpointId,
 					channelInfo,
 					ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
@@ -351,7 +352,7 @@ public class CheckpointBarrierUnaligner extends CheckpointBarrierHandler {
 			Arrays.fill(storeNewBuffers, true);
 			numBarriersReceived = 0;
 			allBarriersReceivedFuture = new CompletableFuture<>();
-			channelStateWriter.start(barrierId, barrier.getCheckpointOptions());
+			checkpointCoordinator.initCheckpoint(barrierId, barrier.getCheckpointOptions());
 		}
 
 		synchronized CompletableFuture<Void> getAllBarriersReceivedFuture(long checkpointId) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java
index 5dbfc02..b397b9e 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/InputProcessorUtil.java
@@ -18,13 +18,13 @@
 package org.apache.flink.streaming.runtime.io;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.metrics.MetricNames;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.tasks.SubtaskCheckpointCoordinator;
 
 import java.util.Arrays;
 import java.util.Collection;
@@ -43,7 +43,7 @@ public class InputProcessorUtil {
 	public static CheckpointedInputGate createCheckpointedInputGate(
 			AbstractInvokable toNotifyOnCheckpoint,
 			StreamConfig config,
-			ChannelStateWriter channelStateWriter,
+			SubtaskCheckpointCoordinator checkpointCoordinator,
 			IndexedInputGate[] inputGates,
 			TaskIOMetricGroup taskIOMetricGroup,
 			String taskName) {
@@ -51,7 +51,7 @@ public class InputProcessorUtil {
 		CheckpointBarrierHandler barrierHandler = createCheckpointBarrierHandler(
 			config,
 			Arrays.stream(inputGates).mapToInt(InputGate::getNumberOfInputChannels),
-			channelStateWriter,
+			checkpointCoordinator,
 			taskName,
 			generateChannelIndexToInputGateMap(inputGate),
 			generateInputGateToChannelIndexOffsetMap(inputGate),
@@ -70,7 +70,7 @@ public class InputProcessorUtil {
 	public static CheckpointedInputGate[] createCheckpointedMultipleInputGate(
 			AbstractInvokable toNotifyOnCheckpoint,
 			StreamConfig config,
-			ChannelStateWriter channelStateWriter,
+			SubtaskCheckpointCoordinator checkpointCoordinator,
 			TaskIOMetricGroup taskIOMetricGroup,
 			String taskName,
 			Collection<IndexedInputGate> ...inputGates) {
@@ -100,7 +100,7 @@ public class InputProcessorUtil {
 		CheckpointBarrierHandler barrierHandler = createCheckpointBarrierHandler(
 			config,
 			numberOfInputChannelsPerGate,
-			channelStateWriter,
+			checkpointCoordinator,
 			taskName,
 			generateChannelIndexToInputGateMap(unionedInputGates),
 			inputGateToChannelIndexOffset,
@@ -126,7 +126,7 @@ public class InputProcessorUtil {
 	private static CheckpointBarrierHandler createCheckpointBarrierHandler(
 			StreamConfig config,
 			IntStream numberOfInputChannelsPerGate,
-			ChannelStateWriter channelStateWriter,
+			SubtaskCheckpointCoordinator checkpointCoordinator,
 			String taskName,
 			InputGate[] channelIndexToInputGate,
 			Map<InputGate, Integer> inputGateToChannelIndexOffset,
@@ -142,7 +142,7 @@ public class InputProcessorUtil {
 							toNotifyOnCheckpoint),
 						new CheckpointBarrierUnaligner(
 							numberOfInputChannelsPerGate.toArray(),
-							channelStateWriter,
+							checkpointCoordinator,
 							taskName,
 							toNotifyOnCheckpoint),
 						toNotifyOnCheckpoint);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
index 15f620b..1578323 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
@@ -94,7 +94,7 @@ public class MultipleInputStreamTask<OUT> extends StreamTask<OUT, MultipleInputS
 		CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate(
 			this,
 			getConfiguration(),
-			getChannelStateWriter(),
+			getCheckpointCoordinator(),
 			getEnvironment().getMetricGroup().getIOMetricGroup(),
 			getTaskNameWithSubtaskAndId(),
 			inputGates);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
index 4411190..cd6cc15 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTask.java
@@ -104,7 +104,7 @@ public class OneInputStreamTask<IN, OUT> extends StreamTask<OUT, OneInputStreamO
 		return InputProcessorUtil.createCheckpointedInputGate(
 			this,
 			configuration,
-			getChannelStateWriter(),
+			getCheckpointCoordinator(),
 			inputGates,
 			getEnvironment().getMetricGroup().getIOMetricGroup(),
 			getTaskNameWithSubtaskAndId());
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index d9add56..a7d04d2 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -315,8 +315,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 		return inputProcessor.prepareSnapshot(channelStateWriter, checkpointId);
 	}
 
-	protected ChannelStateWriter getChannelStateWriter() {
-		return subtaskCheckpointCoordinator.getChannelStateWriter();
+	SubtaskCheckpointCoordinator getCheckpointCoordinator() {
+		return subtaskCheckpointCoordinator;
 	}
 
 	// ------------------------------------------------------------------------
@@ -808,7 +808,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			// No alignment if we inject a checkpoint
 			CheckpointMetrics checkpointMetrics = new CheckpointMetrics().setAlignmentDurationNanos(0L);
 
-			subtaskCheckpointCoordinator.getChannelStateWriter().start(checkpointMetaData.getCheckpointId(), checkpointOptions);
+			subtaskCheckpointCoordinator.initCheckpoint(checkpointMetaData.getCheckpointId(), checkpointOptions);
+
 			boolean success = performCheckpoint(checkpointMetaData, checkpointOptions, checkpointMetrics, advanceToEndOfEventTime);
 			if (!success) {
 				declineCheckpoint(checkpointMetaData.getCheckpointId());
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinator.java
index 2922735..56e9a50 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinator.java
@@ -39,7 +39,12 @@ import java.util.function.Supplier;
  * </ol>
  */
 @Internal
-interface SubtaskCheckpointCoordinator extends Closeable {
+public interface SubtaskCheckpointCoordinator extends Closeable {
+
+	/**
+	 * Initialize new checkpoint.
+	 */
+	void initCheckpoint(long id, CheckpointOptions checkpointOptions);
 
 	ChannelStateWriter getChannelStateWriter();
 
@@ -47,6 +52,9 @@ interface SubtaskCheckpointCoordinator extends Closeable {
 
 	void abortCheckpointOnBarrier(long checkpointId, Throwable cause, OperatorChain<?, ?> operatorChain) throws IOException;
 
+	/**
+	 * Must be called after {@link #initCheckpoint(long, CheckpointOptions)}.
+	 */
 	void checkpointState(
 		CheckpointMetaData checkpointMetaData,
 		CheckpointOptions checkpointOptions,
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
index 9fc7927..d14594f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
@@ -69,7 +69,6 @@ import java.util.concurrent.Future;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
-import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
@@ -122,7 +121,6 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 			DEFAULT_MAX_RECORD_ABORTED_CHECKPOINTS);
 	}
 
-	@VisibleForTesting
 	SubtaskCheckpointCoordinatorImpl(
 			CheckpointStorageWorkerView checkpointStorage,
 			String taskName,
@@ -134,6 +132,33 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 			boolean unalignedCheckpointEnabled,
 			BiFunctionWithException<ChannelStateWriter, Long, CompletableFuture<Void>, IOException> prepareInputSnapshot,
 			int maxRecordAbortedCheckpoints) throws IOException {
+		this(
+			checkpointStorage,
+			taskName,
+			actionExecutor,
+			closeableRegistry,
+			executorService,
+			env,
+			asyncExceptionHandler,
+			unalignedCheckpointEnabled,
+			prepareInputSnapshot,
+			maxRecordAbortedCheckpoints,
+			unalignedCheckpointEnabled ? openChannelStateWriter(taskName, checkpointStorage) : ChannelStateWriter.NO_OP);
+	}
+
+	@VisibleForTesting
+	SubtaskCheckpointCoordinatorImpl(
+			CheckpointStorageWorkerView checkpointStorage,
+			String taskName,
+			StreamTaskActionExecutor actionExecutor,
+			CloseableRegistry closeableRegistry,
+			ExecutorService executorService,
+			Environment env,
+			AsyncExceptionHandler asyncExceptionHandler,
+			boolean unalignedCheckpointEnabled,
+			BiFunctionWithException<ChannelStateWriter, Long, CompletableFuture<Void>, IOException> prepareInputSnapshot,
+			int maxRecordAbortedCheckpoints,
+			ChannelStateWriter channelStateWriter) throws IOException {
 		this.checkpointStorage = new CachingCheckpointStorageWorkerView(checkNotNull(checkpointStorage));
 		this.taskName = checkNotNull(taskName);
 		this.checkpoints = new HashMap<>();
@@ -142,7 +167,7 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 		this.env = checkNotNull(env);
 		this.asyncExceptionHandler = checkNotNull(asyncExceptionHandler);
 		this.actionExecutor = checkNotNull(actionExecutor);
-		this.channelStateWriter = unalignedCheckpointEnabled ? openChannelStateWriter() : ChannelStateWriter.NO_OP;
+		this.channelStateWriter = checkNotNull(channelStateWriter);
 		this.unalignedCheckpointEnabled = unalignedCheckpointEnabled;
 		this.prepareInputSnapshot = prepareInputSnapshot;
 		this.abortedCheckpointIds = createAbortedCheckpointSetWithLimitSize(maxRecordAbortedCheckpoints);
@@ -151,7 +176,7 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 		this.closed = false;
 	}
 
-	private ChannelStateWriter openChannelStateWriter() {
+	private static ChannelStateWriter openChannelStateWriter(String taskName, CheckpointStorageWorkerView checkpointStorage) {
 		ChannelStateWriterImpl writer = new ChannelStateWriterImpl(taskName, checkpointStorage);
 		writer.open();
 		return writer;
@@ -217,7 +242,7 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 			unalignedCheckpointEnabled);
 
 		// Step (3): Prepare to spill the in-flight buffers for input and output
-		if (unalignedCheckpointEnabled && !options.getCheckpointType().isSavepoint()) {
+		if (includeChannelState(options)) {
 			prepareInflightDataSnapshot(metadata.getCheckpointId());
 		}
 
@@ -284,6 +309,17 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 	}
 
 	@Override
+	public void initCheckpoint(long id, CheckpointOptions checkpointOptions) {
+		if (includeChannelState(checkpointOptions)) {
+			channelStateWriter.start(id, checkpointOptions);
+		}
+	}
+
+	private boolean includeChannelState(CheckpointOptions checkpointOptions) {
+		return unalignedCheckpointEnabled && !checkpointOptions.getCheckpointType().isSavepoint();
+	}
+
+	@Override
 	public void close() throws IOException {
 		List<AsyncCheckpointRunnable> asyncCheckpointRunnables = null;
 		synchronized (lock) {
@@ -397,7 +433,7 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 
 	private void finishAndReportAsync(Map<OperatorID, OperatorSnapshotFutures> snapshotFutures, CheckpointMetaData metadata, CheckpointMetrics metrics, CheckpointOptions options) {
 		final Future<?> channelWrittenFuture;
-		if (unalignedCheckpointEnabled && !options.getCheckpointType().isSavepoint()) {
+		if (includeChannelState(options)) {
 			ChannelStateWriteResult writeResult = channelStateWriter.getWriteResult(metadata.getCheckpointId());
 			channelWrittenFuture = CompletableFuture.allOf(
 					writeResult.getInputChannelStateHandles(),
@@ -453,7 +489,7 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
 		long checkpointId = checkpointMetaData.getCheckpointId();
 		long started = System.nanoTime();
 
-		ChannelStateWriteResult channelStateWriteResult = checkpointOptions.getCheckpointType() == CHECKPOINT ?
+		ChannelStateWriteResult channelStateWriteResult = includeChannelState(checkpointOptions) ?
 								channelStateWriter.getWriteResult(checkpointId) :
 								ChannelStateWriteResult.EMPTY;
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java
index c39d9e4..88d1246 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTask.java
@@ -57,7 +57,7 @@ public class TwoInputStreamTask<IN1, IN2, OUT> extends AbstractTwoInputStreamTas
 		CheckpointedInputGate[] checkpointedInputGates = InputProcessorUtil.createCheckpointedMultipleInputGate(
 			this,
 			getConfiguration(),
-			getChannelStateWriter(),
+			getCheckpointCoordinator(),
 			getEnvironment().getMetricGroup().getIOMetricGroup(),
 			getTaskNameWithSubtaskAndId(),
 			inputGates1,
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java
index d723d06..16a6bf2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/AlternatingCheckpointBarrierHandlerTest.java
@@ -21,7 +21,6 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
-import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
@@ -33,6 +32,7 @@ import org.apache.flink.runtime.io.network.partition.consumer.TestInputChannel;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator;
 import org.apache.flink.util.function.ThrowingRunnable;
 
 import org.junit.Test;
@@ -89,7 +89,7 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1));
 		TestInvokable target = new TestInvokable();
 		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, ChannelStateWriter.NO_OP, "test", target);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 
 		for (int i = 0; i < 4; i++) {
@@ -119,7 +119,7 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1));
 		TestInvokable target = new TestInvokable();
 		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, ChannelStateWriter.NO_OP, "test", target);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 
 		final long id = 1;
@@ -135,7 +135,7 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		inputGate.setInputChannels(new TestInputChannel(inputGate, 0), new TestInputChannel(inputGate, 1));
 		TestInvokable target = new TestInvokable();
 		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate, inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, ChannelStateWriter.NO_OP, "test", target);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 
 		long checkpointId = 10;
@@ -157,7 +157,7 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		SingleInputGate inputGate = new SingleInputGateBuilder().setNumberOfChannels(totalChannels).build();
 		TestInvokable target = new TestInvokable();
 		CheckpointBarrierAligner alignedHandler = new CheckpointBarrierAligner("test", new InputGate[]{inputGate}, singletonMap(inputGate, 0), target);
-		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, ChannelStateWriter.NO_OP, "test", target);
+		CheckpointBarrierUnaligner unalignedHandler = new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", target);
 		AlternatingCheckpointBarrierHandler barrierHandler = new AlternatingCheckpointBarrierHandler(alignedHandler, unalignedHandler, target);
 		for (int i = 0; i < closedChannels; i++) {
 			barrierHandler.processEndOfPartition();
@@ -206,7 +206,7 @@ public class AlternatingCheckpointBarrierHandlerTest {
 		Arrays.fill(channelIndexToInputGate, inputGate);
 		return new AlternatingCheckpointBarrierHandler(
 			new CheckpointBarrierAligner(taskName, channelIndexToInputGate, singletonMap(inputGate, 0), target),
-			new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, ChannelStateWriter.NO_OP, taskName, target),
+			new CheckpointBarrierUnaligner(new int[]{inputGate.getNumberOfInputChannels()}, TestSubtaskCheckpointCoordinator.INSTANCE, taskName, target),
 			target);
 	}
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java
index e202d60..1bb4972 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerCancellationTest.java
@@ -20,12 +20,12 @@ package org.apache.flink.streaming.runtime.io;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.event.RuntimeEvent;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
+import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator;
 import org.apache.flink.util.function.ThrowingRunnable;
 
 import org.junit.Test;
@@ -77,7 +77,7 @@ public class CheckpointBarrierUnalignerCancellationTest {
 	@Test
 	public void test() throws Exception {
 		TestInvokable invokable = new TestInvokable();
-		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(new int[]{numChannels}, ChannelStateWriter.NO_OP, "test", invokable);
+		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(new int[]{numChannels}, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
 
 		for (RuntimeEvent e : events) {
 			if (e instanceof CancelCheckpointMarker) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java
index 2e5a61d..0ab8ee2 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/CheckpointBarrierUnalignerTest.java
@@ -20,7 +20,6 @@ package org.apache.flink.streaming.runtime.io;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.checkpoint.channel.RecordingChannelStateWriter;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
@@ -41,6 +40,7 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.streaming.runtime.io.CheckpointBarrierUnaligner.ThreadSafeUnaligner;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator;
 import org.apache.flink.streaming.runtime.tasks.mailbox.MailboxDefaultAction;
 import org.apache.flink.util.function.ThrowingRunnable;
 
@@ -488,7 +488,7 @@ public class CheckpointBarrierUnalignerTest {
 	@Test
 	public void testConcurrentProcessBarrierAndNotifyBarrierReceived() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
-		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(new int[] { 1 }, ChannelStateWriter.NO_OP, "test", invokable);
+		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
 		final InputChannelInfo channelInfo = new InputChannelInfo(0, 0);
 		final ExecutorService executor = Executors.newFixedThreadPool(1);
 
@@ -524,7 +524,7 @@ public class CheckpointBarrierUnalignerTest {
 	public void testProcessCancellationBarrierAfterNotifyBarrierReceived() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
 		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { 1 }, ChannelStateWriter.NO_OP, "test", invokable);
+			new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
 
 		ThreadSafeUnaligner unaligner = handler.getThreadSafeUnaligner();
 		// should trigger respective checkpoint
@@ -547,7 +547,7 @@ public class CheckpointBarrierUnalignerTest {
 	public void testProcessCancellationBarrierAfterProcessBarrier() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
 		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { 1 }, ChannelStateWriter.NO_OP, "test", invokable);
+			new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
 
 		// should trigger respective checkpoint
 		handler.processBarrier(buildCheckpointBarrier(DEFAULT_CHECKPOINT_ID), 0);
@@ -564,7 +564,7 @@ public class CheckpointBarrierUnalignerTest {
 	public void testProcessCancellationBarrierBeforeProcessAndReceiveBarrier() throws Exception {
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
 		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { 1 }, ChannelStateWriter.NO_OP, "test", invokable);
+			new int[] { 1 }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
 
 		handler.processCancellationBarrier(new CancelCheckpointMarker(DEFAULT_CHECKPOINT_ID));
 
@@ -609,7 +609,7 @@ public class CheckpointBarrierUnalignerTest {
 		final int numberOfChannels = 2;
 		final ValidatingCheckpointInvokable invokable = new ValidatingCheckpointInvokable();
 		final CheckpointBarrierUnaligner handler = new CheckpointBarrierUnaligner(
-			new int[] { numberOfChannels }, ChannelStateWriter.NO_OP, "test", invokable);
+			new int[] { numberOfChannels }, TestSubtaskCheckpointCoordinator.INSTANCE, "test", invokable);
 
 		ThreadSafeUnaligner unaligner = handler.getThreadSafeUnaligner();
 		// should trigger respective checkpoint
@@ -705,7 +705,7 @@ public class CheckpointBarrierUnalignerTest {
 	private CheckpointedInputGate createCheckpointedInputGate(InputGate gate, AbstractInvokable toNotify) {
 		final CheckpointBarrierUnaligner barrierHandler = new CheckpointBarrierUnaligner(
 			new int[]{ gate.getNumberOfInputChannels() },
-			channelStateWriter,
+			new TestSubtaskCheckpointCoordinator(channelStateWriter),
 			"Test",
 			toNotify);
 		barrierHandler.getBufferReceivedListener().ifPresent(gate::registerBufferReceivedListener);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java
index 8a511ef..37a55d6 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/InputProcessorUtilTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironmentBuilder;
 import org.apache.flink.streaming.api.CheckpointingMode;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator;
 import org.apache.flink.streaming.util.MockStreamTask;
 import org.apache.flink.streaming.util.MockStreamTaskBuilder;
 
@@ -96,7 +97,7 @@ public class InputProcessorUtilTest {
 			CheckpointedInputGate[] checkpointedMultipleInputGate = InputProcessorUtil.createCheckpointedMultipleInputGate(
 				streamTask,
 				streamConfig,
-				new MockChannelStateWriter(),
+				new TestSubtaskCheckpointCoordinator(new MockChannelStateWriter()),
 				environment.getMetricGroup().getIOMetricGroup(),
 				streamTask.getName(),
 				inputGates);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java
index 0793bf1..b61da52 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/StreamTaskNetworkInputTest.java
@@ -52,6 +52,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamElementSerializer;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.streamstatus.StatusWatermarkValve;
 import org.apache.flink.streaming.runtime.streamstatus.StreamStatus;
+import org.apache.flink.streaming.runtime.tasks.TestSubtaskCheckpointCoordinator;
 
 import org.junit.After;
 import org.junit.Test;
@@ -122,7 +123,7 @@ public class StreamTaskNetworkInputTest {
 	public void testSnapshotWithTwoInputGates() throws Exception {
 		CheckpointBarrierUnaligner unaligner = new CheckpointBarrierUnaligner(
 				new int[]{ 1, 1 },
-				ChannelStateWriter.NO_OP,
+				TestSubtaskCheckpointCoordinator.INSTANCE,
 				"test",
 				new DummyCheckpointInvokable());
 
@@ -194,7 +195,7 @@ public class StreamTaskNetworkInputTest {
 				inputGate.getInputGate(),
 				new CheckpointBarrierUnaligner(
 					new int[] { numInputChannels },
-					ChannelStateWriter.NO_OP,
+					TestSubtaskCheckpointCoordinator.INSTANCE,
 					"test",
 					new DummyCheckpointInvokable())),
 			inSerializer,
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
index 94bb92f..580ab54 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
@@ -18,11 +18,15 @@
 
 package org.apache.flink.streaming.runtime.tasks;
 
+import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
 import org.apache.flink.runtime.io.network.api.writer.NonRecordWriter;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
@@ -32,6 +36,7 @@ import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.DoneFuture;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.runtime.state.TestCheckpointStorageWorkerView;
 import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -46,6 +51,7 @@ import org.apache.flink.util.ExceptionUtils;
 
 import org.junit.Test;
 
+import java.io.IOException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
@@ -53,7 +59,9 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.RunnableFuture;
 import java.util.concurrent.TimeUnit;
 
+import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
 import static org.apache.flink.runtime.checkpoint.CheckpointType.SAVEPOINT;
+import static org.apache.flink.shaded.guava18.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -65,6 +73,29 @@ import static org.junit.Assert.fail;
 public class SubtaskCheckpointCoordinatorTest {
 
 	@Test
+	public void testInitCheckpoint() throws IOException { // each helper call reports whether the writer's start() was invoked
+		assertTrue(initCheckpoint(true, CHECKPOINT)); // unaligned enabled + regular checkpoint: writer must be started
+		assertFalse(initCheckpoint(true, SAVEPOINT)); // savepoint: writer must not be started even with unaligned enabled
+		assertFalse(initCheckpoint(false, CHECKPOINT)); // unaligned disabled: writer must not be started
+		assertFalse(initCheckpoint(false, SAVEPOINT)); // unaligned disabled + savepoint: writer must not be started
+	}
+
+	/** Calls {@code initCheckpoint} for checkpoint id 1 on a coordinator from {@link #coordinator} and reports whether the writer's {@code start} ran. */
+	private boolean initCheckpoint(boolean unalignedCheckpointEnabled, CheckpointType checkpointType) throws IOException {
+		final boolean[] startCalled = {false};
+		ChannelStateWriter recordingWriter = new ChannelStateWriterImpl.NoOpChannelStateWriter() {
+			@Override
+			public void start(long checkpointId, CheckpointOptions checkpointOptions) {
+				startCalled[0] = true;
+			}
+		};
+		SubtaskCheckpointCoordinator subject = coordinator(unalignedCheckpointEnabled, recordingWriter);
+		CheckpointOptions options = new CheckpointOptions(checkpointType, CheckpointStorageLocationReference.getDefault());
+		subject.initCheckpoint(1L, options);
+		return startCalled[0];
+	}
+
+	@Test
 	public void testNotifyCheckpointComplete() throws Exception {
 		TestTaskStateManager stateManager = new TestTaskStateManager();
 		MockEnvironment mockEnvironment = MockEnvironment.builder().setTaskStateManager(stateManager).build();
@@ -373,4 +404,20 @@ public class SubtaskCheckpointCoordinatorTest {
 		public void processLatencyMarker(LatencyMarker latencyMarker) {
 		}
 	}
+
+	private static SubtaskCheckpointCoordinator coordinator(boolean unalignedCheckpointEnabled, ChannelStateWriter channelStateWriter) throws IOException {
+		return new SubtaskCheckpointCoordinatorImpl( // real implementation under test, wired with test doubles
+			new TestCheckpointStorageWorkerView(100), // NOTE(review): presumably a max state size, cf. DEFAULT_MAX_STATE_SIZE in TestSubtaskCheckpointCoordinator - confirm
+			"test",
+			StreamTaskActionExecutor.IMMEDIATE,
+			new CloseableRegistry(),
+			newDirectExecutorService(), // Guava direct executor: submitted tasks run on the calling thread
+			new DummyEnvironment(),
+			(message, unused) -> fail(message), // any message handed to this callback raises a JUnit failure
+			unalignedCheckpointEnabled, // passed through unchanged from the caller
+			(unused1, unused2) -> CompletableFuture.completedFuture(null), // ignores its arguments, returns an already-completed future
+			0, // NOTE(review): meaning of this argument is defined by the constructor, which is not visible here
+			channelStateWriter // passed through unchanged from the caller
+		);
+	}
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TestSubtaskCheckpointCoordinator.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TestSubtaskCheckpointCoordinator.java
new file mode 100644
index 0000000..3897ab1
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TestSubtaskCheckpointCoordinator.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.tasks;
+
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
+import org.apache.flink.runtime.state.TestCheckpointStorageWorkerView;
+
+import java.util.function.Supplier;
+
+/**
+ * Lightweight {@link SubtaskCheckpointCoordinator} stand-in for tests.
+ *
+ * <p>{@link #initCheckpoint} and {@link #abortCheckpointOnBarrier} are forwarded to the wrapped
+ * {@link ChannelStateWriter}; {@link #checkpointState}, both notification callbacks and
+ * {@link #close()} are empty.
+ */
+public class TestSubtaskCheckpointCoordinator implements SubtaskCheckpointCoordinator {
+
+	/** Shared coordinator backed by {@link ChannelStateWriter#NO_OP}. */
+	public static final TestSubtaskCheckpointCoordinator INSTANCE = new TestSubtaskCheckpointCoordinator();
+
+	private static final int DEFAULT_MAX_STATE_SIZE = 1000;
+
+	private final CheckpointStorageWorkerView storageWorkerView;
+	private final ChannelStateWriter channelStateWriter;
+
+	private TestSubtaskCheckpointCoordinator() {
+		this(ChannelStateWriter.NO_OP);
+	}
+
+	/**
+	 * Creates a coordinator whose channel-state calls go to the given writer.
+	 *
+	 * @param channelStateWriter writer that receives the {@code start} and {@code abort} calls
+	 */
+	public TestSubtaskCheckpointCoordinator(ChannelStateWriter channelStateWriter) {
+		this(new TestCheckpointStorageWorkerView(DEFAULT_MAX_STATE_SIZE), channelStateWriter);
+	}
+
+	private TestSubtaskCheckpointCoordinator(CheckpointStorageWorkerView storageWorkerView, ChannelStateWriter channelStateWriter) {
+		this.channelStateWriter = channelStateWriter;
+		this.storageWorkerView = storageWorkerView;
+	}
+
+	@Override
+	public CheckpointStorageWorkerView getCheckpointStorage() {
+		return storageWorkerView;
+	}
+
+	@Override
+	public ChannelStateWriter getChannelStateWriter() {
+		return channelStateWriter;
+	}
+
+	/** Starts the writer unconditionally; the checkpoint type in the options is not inspected. */
+	@Override
+	public void initCheckpoint(long id, CheckpointOptions checkpointOptions) {
+		channelStateWriter.start(id, checkpointOptions);
+	}
+
+	/** Aborts the checkpoint on the writer only; the operator chain is ignored. */
+	@Override
+	public void abortCheckpointOnBarrier(long checkpointId, Throwable cause, OperatorChain<?, ?> operatorChain) {
+		channelStateWriter.abort(checkpointId, cause);
+	}
+
+	@Override
+	public void checkpointState(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions, CheckpointMetrics checkpointMetrics, OperatorChain<?, ?> operatorChain, Supplier<Boolean> isCanceled) {
+		// intentionally empty: this stub takes no snapshot
+	}
+
+	@Override
+	public void notifyCheckpointComplete(long checkpointId, OperatorChain<?, ?> operatorChain, Supplier<Boolean> isRunning) {
+		// intentionally empty
+	}
+
+	@Override
+	public void notifyCheckpointAborted(long checkpointId, OperatorChain<?, ?> operatorChain, Supplier<Boolean> isRunning) {
+		// intentionally empty
+	}
+
+	@Override
+	public void close() {
+		// intentionally empty
+	}
+}