You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2023/01/23 15:47:27 UTC

[flink] 03/03: [FLINK-26803][checkpoint] Merge the channel state files

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 8be94e6663d8ac6e3d74bf4cd5f540cc96c8289e
Author: 1996fanrui <19...@gmail.com>
AuthorDate: Tue Nov 8 12:14:15 2022 +0800

    [FLINK-26803][checkpoint] Merge the channel state files
    
    fixup! Address comments
    
    fixup! fixup! Address comments
    
    1. Add some job docs
    2. Using the lock to refactor the ChannelStateWriteRequestExecutorImpl
    
    fixup! fixup! fixup! Address comments
---
 .../execution_checkpointing_configuration.html     |   6 +
 .../state/api/runtime/SavepointEnvironment.java    |   9 +
 .../checkpoint/CheckpointFailureManager.java       |   1 +
 .../checkpoint/CheckpointFailureReason.java        |   4 +
 .../channel/ChannelStateCheckpointWriter.java      | 416 ++++++++++++---------
 .../channel/ChannelStatePendingResult.java         | 192 ++++++++++
 .../channel/ChannelStateWriteRequest.java          | 229 +++++++++---
 .../ChannelStateWriteRequestDispatcherImpl.java    | 176 ++++++---
 .../channel/ChannelStateWriteRequestExecutor.java  |  14 +-
 .../ChannelStateWriteRequestExecutorFactory.java   |  67 ++++
 .../ChannelStateWriteRequestExecutorImpl.java      | 311 +++++++++++++--
 .../checkpoint/channel/ChannelStateWriterImpl.java |  85 +++--
 .../flink/runtime/execution/Environment.java       |   3 +
 .../io/network/logger/NetworkActionsLogger.java    |   5 +-
 ...ExecutorChannelStateExecutorFactoryManager.java | 106 ++++++
 .../flink/runtime/taskexecutor/TaskExecutor.java   |  14 +-
 .../runtime/taskexecutor/TaskManagerServices.java  |  12 +
 .../runtime/taskmanager/RuntimeEnvironment.java    |  12 +-
 .../org/apache/flink/runtime/taskmanager/Task.java |  11 +-
 .../channel/ChannelStateCheckpointWriterTest.java  | 257 +++++++++++--
 ...ChannelStateWriteRequestDispatcherImplTest.java | 176 +++++++--
 .../ChannelStateWriteRequestDispatcherTest.java    |  35 +-
 ...hannelStateWriteRequestExecutorFactoryTest.java |  72 ++++
 .../ChannelStateWriteRequestExecutorImplTest.java  | 326 ++++++++++++++--
 .../channel/ChannelStateWriteResultUtil.java       |  75 ++++
 .../channel/ChannelStateWriterImplTest.java        | 113 +++---
 .../channel/CheckpointInProgressRequestTest.java   |   4 +
 .../operators/testutils/DummyEnvironment.java      |   8 +
 .../operators/testutils/MockEnvironment.java       |  12 +-
 .../testutils/MockEnvironmentBuilder.java          |  12 +-
 .../runtime/state/ChannelPersistenceITCase.java    |  41 +-
 ...utorChannelStateExecutorFactoryManagerTest.java |  78 ++++
 .../taskexecutor/TaskManagerServicesBuilder.java   |  11 +
 .../runtime/taskmanager/TaskAsyncCallTest.java     |   4 +-
 .../flink/runtime/taskmanager/TestTaskBuilder.java |   4 +-
 .../runtime/util/JvmExitOnFatalErrorTest.java      |   5 +-
 .../api/environment/CheckpointConfig.java          |  27 ++
 .../environment/ExecutionCheckpointingOptions.java |   9 +
 .../flink/streaming/api/graph/StreamConfig.java    |  11 +
 .../api/graph/StreamingJobGraphGenerator.java      |   1 +
 .../flink/streaming/runtime/tasks/StreamTask.java  |   4 +-
 .../tasks/SubtaskCheckpointCoordinatorImpl.java    |  29 +-
 .../tasks/InterruptSensitiveRestoreTest.java       |   4 +-
 .../MockSubtaskCheckpointCoordinatorBuilder.java   |  21 +-
 .../runtime/tasks/StreamMockEnvironment.java       |   9 +
 .../runtime/tasks/StreamTaskSystemExitTest.java    |   4 +-
 .../runtime/tasks/StreamTaskTerminationTest.java   |   4 +-
 .../tasks/SubtaskCheckpointCoordinatorTest.java    |  28 +-
 .../runtime/tasks/SynchronousCheckpointITCase.java |   4 +-
 .../tasks/TaskCheckpointingBehaviourTest.java      |   4 +-
 .../test/state/ChangelogRecoveryCachingITCase.java |   4 +
 51 files changed, 2497 insertions(+), 572 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html b/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html
index 1889bd7e9f1..fcecef43791 100644
--- a/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html
+++ b/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html
@@ -14,6 +14,12 @@
             <td>Duration</td>
             <td>Only relevant if <code class="highlighter-rouge">execution.checkpointing.unaligned.enabled</code> is enabled.<br /><br />If timeout is 0, checkpoints will always start unaligned.<br /><br />If timeout has a positive value, checkpoints will start aligned. If during checkpointing, checkpoint start delay exceeds this timeout, alignment will timeout and checkpoint barrier will start working as unaligned checkpoint.</td>
         </tr>
+        <tr>
+            <td><h5>execution.checkpointing.unaligned.max-subtasks-per-channel-state-file</h5></td>
+            <td style="word-wrap: break-word;">5</td>
+            <td>Integer</td>
+            <td>Defines the maximum number of subtasks that share the same channel state file. It can reduce the number of small files when enable unaligned checkpoint. Each subtask will create a new channel state file when this is configured to 1.</td>
+        </tr>
         <tr>
             <td><h5>execution.checkpointing.checkpoints-after-tasks-finish.enabled</h5></td>
             <td style="word-wrap: break-word;">true</td>
diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
index ee230c39d4b..73c17de4104 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.PrioritizedOperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphID;
@@ -107,6 +108,8 @@ public class SavepointEnvironment implements Environment {
 
     private final UserCodeClassLoader userCodeClassLoader;
 
+    private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
     private SavepointEnvironment(
             RuntimeContext ctx,
             Configuration configuration,
@@ -133,6 +136,7 @@ public class SavepointEnvironment implements Environment {
         this.accumulatorRegistry = new AccumulatorRegistry(jobID, attemptID);
 
         this.userCodeClassLoader = UserCodeClassLoaderRuntimeContextAdapter.from(ctx);
+        this.channelStateExecutorFactory = new ChannelStateWriteRequestExecutorFactory(jobID);
     }
 
     @Override
@@ -306,6 +310,11 @@ public class SavepointEnvironment implements Environment {
         throw new UnsupportedOperationException(ERROR_MSG);
     }
 
+    @Override
+    public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+        return channelStateExecutorFactory;
+    }
+
     /** {@link SavepointEnvironment} builder. */
     public static class Builder {
         private RuntimeContext ctx;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java
index 4b33a0aab1c..fabff241649 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java
@@ -228,6 +228,7 @@ public class CheckpointFailureManager {
             case CHECKPOINT_SUBSUMED:
             case CHECKPOINT_COORDINATOR_SUSPEND:
             case CHECKPOINT_COORDINATOR_SHUTDOWN:
+            case CHANNEL_STATE_SHARED_STREAM_EXCEPTION:
             case JOB_FAILOVER_REGION:
                 // for compatibility purposes with user job behavior
             case CHECKPOINT_DECLINED_TASK_NOT_READY:
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java
index e85b0130422..6f56bd67b6e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java
@@ -36,6 +36,10 @@ public enum CheckpointFailureReason {
 
     CHECKPOINT_ASYNC_EXCEPTION(false, "Asynchronous task checkpoint failed."),
 
+    CHANNEL_STATE_SHARED_STREAM_EXCEPTION(
+            false,
+            "The checkpoint was aborted due to exception of other subtasks sharing the ChannelState file."),
+
     CHECKPOINT_EXPIRED(false, "Checkpoint expired before completing."),
 
     CHECKPOINT_SUBSUMED(false, "Checkpoint has been subsumed."),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
index c858896c5d5..4173bb7140e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
@@ -18,76 +18,72 @@
 package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger;
-import org.apache.flink.runtime.state.AbstractChannelStateHandle;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
 import org.apache.flink.runtime.state.CheckpointStateOutputStream;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.InputChannelStateHandle;
-import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.function.RunnableWithException;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nonnull;
 import javax.annotation.concurrent.NotThreadSafe;
 
 import java.io.DataOutputStream;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collection;
 import java.util.HashMap;
-import java.util.List;
+import java.util.HashSet;
 import java.util.Map;
-import java.util.Optional;
-import java.util.concurrent.CompletableFuture;
+import java.util.Objects;
+import java.util.Set;
 
-import static java.util.Collections.emptyList;
-import static java.util.Collections.singletonList;
-import static java.util.UUID.randomUUID;
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
 import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
 import static org.apache.flink.util.ExceptionUtils.rethrow;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
-/** Writes channel state for a specific checkpoint-subtask-attempt triple. */
+/** Writes channel state for multiple subtasks of the same checkpoint. */
 @NotThreadSafe
 class ChannelStateCheckpointWriter {
     private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class);
 
     private final DataOutputStream dataStream;
     private final CheckpointStateOutputStream checkpointStream;
-    private final ChannelStateWriteResult result;
-    private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>();
-    private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets =
-            new HashMap<>();
+
+    /**
+     * The failure of the current checkpoint shared by all subtasks of this writer. If it is not
+     * null, the checkpoint will fail.
+     */
+    private Throwable throwable;
+
     private final ChannelStateSerializer serializer;
     private final long checkpointId;
-    private boolean allInputsReceived = false;
-    private boolean allOutputsReceived = false;
     private final RunnableWithException onComplete;
-    private final int subtaskIndex;
-    private final String taskName;
+
+    // Subtasks that have not yet registered their writer result.
+    private final Set<SubtaskID> subtasksToRegister;
+
+    private final Map<SubtaskID, ChannelStatePendingResult> pendingResults = new HashMap<>();
 
     ChannelStateCheckpointWriter(
-            String taskName,
-            int subtaskIndex,
-            CheckpointStartRequest startCheckpointItem,
+            Set<SubtaskID> subtasks,
+            long checkpointId,
             CheckpointStreamFactory streamFactory,
             ChannelStateSerializer serializer,
             RunnableWithException onComplete)
             throws Exception {
         this(
-                taskName,
-                subtaskIndex,
-                startCheckpointItem.getCheckpointId(),
-                startCheckpointItem.getTargetResult(),
+                subtasks,
+                checkpointId,
                 streamFactory.createCheckpointStateOutputStream(EXCLUSIVE),
                 serializer,
                 onComplete);
@@ -95,38 +91,25 @@ class ChannelStateCheckpointWriter {
 
     @VisibleForTesting
     ChannelStateCheckpointWriter(
-            String taskName,
-            int subtaskIndex,
+            Set<SubtaskID> subtasks,
             long checkpointId,
-            ChannelStateWriteResult result,
             CheckpointStateOutputStream stream,
             ChannelStateSerializer serializer,
             RunnableWithException onComplete) {
-        this(
-                taskName,
-                subtaskIndex,
-                checkpointId,
-                result,
-                serializer,
-                onComplete,
-                stream,
-                new DataOutputStream(stream));
+        this(subtasks, checkpointId, serializer, onComplete, stream, new DataOutputStream(stream));
     }
 
     @VisibleForTesting
     ChannelStateCheckpointWriter(
-            String taskName,
-            int subtaskIndex,
+            Set<SubtaskID> subtasks,
             long checkpointId,
-            ChannelStateWriteResult result,
             ChannelStateSerializer serializer,
             RunnableWithException onComplete,
             CheckpointStateOutputStream checkpointStateOutputStream,
             DataOutputStream dataStream) {
-        this.taskName = taskName;
-        this.subtaskIndex = subtaskIndex;
+        checkArgument(!subtasks.isEmpty(), "The subtasks cannot be empty.");
+        this.subtasksToRegister = new HashSet<>(subtasks);
         this.checkpointId = checkpointId;
-        this.result = checkNotNull(result);
         this.checkpointStream = checkNotNull(checkpointStateOutputStream);
         this.serializer = checkNotNull(serializer);
         this.dataStream = checkNotNull(dataStream);
@@ -134,164 +117,160 @@ class ChannelStateCheckpointWriter {
         runWithChecks(() -> serializer.writeHeader(dataStream));
     }
 
-    void writeInput(InputChannelInfo info, Buffer buffer) {
-        write(
-                inputChannelOffsets,
-                info,
-                buffer,
-                !allInputsReceived,
-                "ChannelStateCheckpointWriter#writeInput");
+    void registerSubtaskResult(
+            SubtaskID subtaskID, ChannelStateWriter.ChannelStateWriteResult result) {
+        // The writer shouldn't register any subtask after it has failed or is done.
+        checkState(!isDone(), "The write is done.");
+        Preconditions.checkState(
+                !pendingResults.containsKey(subtaskID),
+                "The subtask %s has already been register before.",
+                subtaskID);
+        subtasksToRegister.remove(subtaskID);
+
+        ChannelStatePendingResult pendingResult =
+                new ChannelStatePendingResult(
+                        subtaskID.getSubtaskIndex(), checkpointId, result, serializer);
+        pendingResults.put(subtaskID, pendingResult);
     }
 
-    void writeOutput(ResultSubpartitionInfo info, Buffer buffer) {
-        write(
-                resultSubpartitionOffsets,
-                info,
-                buffer,
-                !allOutputsReceived,
-                "ChannelStateCheckpointWriter#writeOutput");
+    void releaseSubtask(SubtaskID subtaskID) throws Exception {
+        if (subtasksToRegister.remove(subtaskID)) {
+            // The checkpoints of all other subtasks of this writer may already be complete,
+            // with the writer waiting only for this subtask. Once this subtask is released,
+            // the writer can be completed.
+            tryFinishResult();
+        }
     }
 
-    private <K> void write(
-            Map<K, StateContentMetaInfo> offsets,
-            K key,
-            Buffer buffer,
-            boolean precondition,
-            String action) {
+    void writeInput(
+            JobVertexID jobVertexID, int subtaskIndex, InputChannelInfo info, Buffer buffer) {
         try {
-            if (result.isDone()) {
+            if (isDone()) {
                 return;
             }
-            runWithChecks(
-                    () -> {
-                        checkState(precondition);
-                        long offset = checkpointStream.getPos();
-                        try (AutoCloseable ignored =
-                                NetworkActionsLogger.measureIO(action, buffer)) {
-                            serializer.writeData(dataStream, buffer);
-                        }
-                        long size = checkpointStream.getPos() - offset;
-                        offsets.computeIfAbsent(key, unused -> new StateContentMetaInfo())
-                                .withDataAdded(offset, size);
-                        NetworkActionsLogger.tracePersist(
-                                action, buffer, taskName, key, checkpointId);
-                    });
+            ChannelStatePendingResult pendingResult =
+                    getChannelStatePendingResult(jobVertexID, subtaskIndex);
+            write(
+                    pendingResult.getInputChannelOffsets(),
+                    info,
+                    buffer,
+                    !pendingResult.isAllInputsReceived(),
+                    "ChannelStateCheckpointWriter#writeInput");
         } finally {
             buffer.recycleBuffer();
         }
     }
 
-    void completeInput() throws Exception {
-        LOG.debug("complete input, output completed: {}", allOutputsReceived);
-        complete(!allInputsReceived, () -> allInputsReceived = true);
+    void writeOutput(
+            JobVertexID jobVertexID, int subtaskIndex, ResultSubpartitionInfo info, Buffer buffer) {
+        try {
+            if (isDone()) {
+                return;
+            }
+            ChannelStatePendingResult pendingResult =
+                    getChannelStatePendingResult(jobVertexID, subtaskIndex);
+            write(
+                    pendingResult.getResultSubpartitionOffsets(),
+                    info,
+                    buffer,
+                    !pendingResult.isAllOutputsReceived(),
+                    "ChannelStateCheckpointWriter#writeOutput");
+        } finally {
+            buffer.recycleBuffer();
+        }
     }
 
-    void completeOutput() throws Exception {
-        LOG.debug("complete output, input completed: {}", allInputsReceived);
-        complete(!allOutputsReceived, () -> allOutputsReceived = true);
+    private <K> void write(
+            Map<K, StateContentMetaInfo> offsets,
+            K key,
+            Buffer buffer,
+            boolean precondition,
+            String action) {
+        runWithChecks(
+                () -> {
+                    checkState(precondition);
+                    long offset = checkpointStream.getPos();
+                    try (AutoCloseable ignored = NetworkActionsLogger.measureIO(action, buffer)) {
+                        serializer.writeData(dataStream, buffer);
+                    }
+                    long size = checkpointStream.getPos() - offset;
+                    offsets.computeIfAbsent(key, unused -> new StateContentMetaInfo())
+                            .withDataAdded(offset, size);
+                    NetworkActionsLogger.tracePersist(action, buffer, key, checkpointId);
+                });
     }
 
-    private void complete(boolean precondition, RunnableWithException complete) throws Exception {
-        if (result.isDone()) {
-            // likely after abort - only need to set the flag run onComplete callback
-            doComplete(precondition, complete, onComplete);
-        } else {
-            runWithChecks(
-                    () ->
-                            doComplete(
-                                    precondition,
-                                    complete,
-                                    onComplete,
-                                    this::finishWriteAndResult));
+    void completeInput(JobVertexID jobVertexID, int subtaskIndex) throws Exception {
+        if (isDone()) {
+            return;
         }
+        getChannelStatePendingResult(jobVertexID, subtaskIndex).completeInput();
+        tryFinishResult();
     }
 
-    private void finishWriteAndResult() throws IOException {
-        if (inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty()) {
-            dataStream.close();
-            result.inputChannelStateHandles.complete(emptyList());
-            result.resultSubpartitionStateHandles.complete(emptyList());
+    void completeOutput(JobVertexID jobVertexID, int subtaskIndex) throws Exception {
+        if (isDone()) {
             return;
         }
-        dataStream.flush();
-        StreamStateHandle underlying = checkpointStream.closeAndGetHandle();
-        complete(
-                underlying,
-                result.inputChannelStateHandles,
-                inputChannelOffsets,
-                HandleFactory.INPUT_CHANNEL);
-        complete(
-                underlying,
-                result.resultSubpartitionStateHandles,
-                resultSubpartitionOffsets,
-                HandleFactory.RESULT_SUBPARTITION);
+        getChannelStatePendingResult(jobVertexID, subtaskIndex).completeOutput();
+        tryFinishResult();
     }
 
-    private void doComplete(
-            boolean precondition,
-            RunnableWithException complete,
-            RunnableWithException... callbacks)
-            throws Exception {
-        Preconditions.checkArgument(precondition);
-        complete.run();
-        if (allInputsReceived && allOutputsReceived) {
-            for (RunnableWithException callback : callbacks) {
-                callback.run();
+    public void tryFinishResult() throws Exception {
+        if (!subtasksToRegister.isEmpty()) {
+            // Some subtasks are not registered yet
+            return;
+        }
+        for (ChannelStatePendingResult result : pendingResults.values()) {
+            if (result.isAllInputsReceived() && result.isAllOutputsReceived()) {
+                continue;
             }
+            // Some subtasks did not receive all buffers
+            return;
         }
-    }
 
-    private <I, H extends AbstractChannelStateHandle<I>> void complete(
-            StreamStateHandle underlying,
-            CompletableFuture<Collection<H>> future,
-            Map<I, StateContentMetaInfo> offsets,
-            HandleFactory<I, H> handleFactory)
-            throws IOException {
-        final Collection<H> handles = new ArrayList<>();
-        for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
-            handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
+        if (isDone()) {
+            // likely after abort - only need to run the onComplete callback
+            doComplete(onComplete);
+        } else {
+            runWithChecks(() -> doComplete(onComplete, this::finishWriteAndResult));
         }
-        future.complete(handles);
-        LOG.debug(
-                "channel state write completed, checkpointId: {}, handles: {}",
-                checkpointId,
-                handles);
     }
 
-    private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
-            HandleFactory<I, H> handleFactory,
-            StreamStateHandle underlying,
-            I channelInfo,
-            StateContentMetaInfo contentMetaInfo)
-            throws IOException {
-        Optional<byte[]> bytes =
-                underlying.asBytesIfInMemory(); // todo: consider restructuring channel state and
-        // removing this method:
-        // https://issues.apache.org/jira/browse/FLINK-17972
-        if (bytes.isPresent()) {
-            StreamStateHandle extracted =
-                    new ByteStreamStateHandle(
-                            randomUUID().toString(),
-                            serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
-            return handleFactory.create(
-                    subtaskIndex,
-                    channelInfo,
-                    extracted,
-                    singletonList(serializer.getHeaderLength()),
-                    extracted.getStateSize());
+    private void finishWriteAndResult() throws IOException {
+        StreamStateHandle stateHandle = null;
+        if (checkpointStream.getPos() == serializer.getHeaderLength()) {
+            dataStream.close();
         } else {
-            return handleFactory.create(
-                    subtaskIndex,
-                    channelInfo,
-                    underlying,
-                    contentMetaInfo.getOffsets(),
-                    contentMetaInfo.getSize());
+            dataStream.flush();
+            stateHandle = checkpointStream.closeAndGetHandle();
+        }
+        for (ChannelStatePendingResult result : pendingResults.values()) {
+            result.finishResult(stateHandle);
+        }
+    }
+
+    private void doComplete(RunnableWithException... callbacks) throws Exception {
+        for (RunnableWithException callback : callbacks) {
+            callback.run();
+        }
+    }
+
+    public boolean isDone() {
+        if (throwable != null) {
+            return true;
+        }
+        for (ChannelStatePendingResult result : pendingResults.values()) {
+            if (result.isDone()) {
+                return true;
+            }
         }
+        return false;
     }
 
     private void runWithChecks(RunnableWithException r) {
         try {
-            checkState(!result.isDone(), "result is already completed", result);
+            checkState(!isDone(), "results are already completed", pendingResults.values());
             r.run();
         } catch (Exception e) {
             fail(e);
@@ -301,8 +280,38 @@ class ChannelStateCheckpointWriter {
         }
     }
 
-    public void fail(Throwable e) {
-        result.fail(e);
+    /**
+     * The throwable is only used for the specific subtask that triggered the failure. Other
+     * subtasks fail with {@code CHANNEL_STATE_SHARED_STREAM_EXCEPTION} wrapping the throwable.
+     */
+    public void fail(JobVertexID jobVertexID, int subtaskIndex, Throwable throwable) {
+        if (isDone()) {
+            return;
+        }
+        this.throwable = throwable;
+
+        ChannelStatePendingResult result =
+                pendingResults.get(SubtaskID.of(jobVertexID, subtaskIndex));
+        if (result != null) {
+            result.fail(throwable);
+        }
+        failResultAndCloseStream(
+                new CheckpointException(CHANNEL_STATE_SHARED_STREAM_EXCEPTION, throwable));
+    }
+
+    public void fail(Throwable throwable) {
+        if (isDone()) {
+            return;
+        }
+        this.throwable = throwable;
+
+        failResultAndCloseStream(throwable);
+    }
+
+    public void failResultAndCloseStream(Throwable e) {
+        for (ChannelStatePendingResult result : pendingResults.values()) {
+            result.fail(e);
+        }
         try {
             checkpointStream.close();
         } catch (Exception closeException) {
@@ -315,18 +324,59 @@ class ChannelStateCheckpointWriter {
         }
     }
 
-    private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
-        H create(
-                int subtaskIndex,
-                I info,
-                StreamStateHandle underlying,
-                List<Long> offsets,
-                long size);
+    @Nonnull
+    private ChannelStatePendingResult getChannelStatePendingResult(
+            JobVertexID jobVertexID, int subtaskIndex) {
+        SubtaskID subtaskID = SubtaskID.of(jobVertexID, subtaskIndex);
+        ChannelStatePendingResult pendingResult = pendingResults.get(subtaskID);
+        checkNotNull(pendingResult, "The subtask[%s] is not registered yet", subtaskID);
+        return pendingResult;
+    }
+}
+
+/** An identifier for a subtask, composed of its job vertex ID and subtask index. */
+class SubtaskID {
+
+    private final JobVertexID jobVertexID;
+    private final int subtaskIndex;
+
+    private SubtaskID(JobVertexID jobVertexID, int subtaskIndex) {
+        this.jobVertexID = jobVertexID;
+        this.subtaskIndex = subtaskIndex;
+    }
+
+    public JobVertexID getJobVertexID() {
+        return jobVertexID;
+    }
+
+    public int getSubtaskIndex() {
+        return subtaskIndex;
+    }
+
+    public static SubtaskID of(JobVertexID jobVertexID, int subtaskIndex) {
+        return new SubtaskID(jobVertexID, subtaskIndex);
+    }
 
-        HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL =
-                InputChannelStateHandle::new;
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        SubtaskID subtaskID = (SubtaskID) o;
+        return subtaskIndex == subtaskID.subtaskIndex
+                && Objects.equals(jobVertexID, subtaskID.jobVertexID);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(jobVertexID, subtaskIndex);
+    }
 
-        HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION =
-                ResultSubpartitionStateHandle::new;
+    @Override
+    public String toString() {
+        return "SubtaskID{" + "jobVertexID=" + jobVertexID + ", subtaskIndex=" + subtaskIndex + '}';
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java
new file mode 100644
index 00000000000..405ce0f392b
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.runtime.state.AbstractChannelStateHandle;
+import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
+import org.apache.flink.runtime.state.InputChannelStateHandle;
+import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+import static java.util.Collections.singletonList;
+import static java.util.UUID.randomUUID;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * The pending result of channel state for a specific checkpoint-subtask.
+ *
+ * <p>It collects, per input channel and per result subpartition, the {@link StateContentMetaInfo}
+ * (offsets and size) of the data written by the subtask, and tracks whether all inputs and all
+ * outputs have been received. {@link #finishResult(StreamStateHandle)} then turns these offsets
+ * into channel state handles and completes the subtask's {@link
+ * ChannelStateWriter.ChannelStateWriteResult}.
+ *
+ * <p>The maps returned by {@link #getInputChannelOffsets()} and {@link
+ * #getResultSubpartitionOffsets()} are the live, mutable maps of this result (presumably filled by
+ * the owning channel state writer). No synchronization is done in this class.
+ */
+public class ChannelStatePendingResult {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ChannelStatePendingResult.class);
+
+    // Subtask information
+    private final int subtaskIndex;
+
+    private final long checkpointId;
+
+    // Result related
+    private final ChannelStateSerializer serializer;
+    private final ChannelStateWriter.ChannelStateWriteResult result;
+    private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets =
+            new HashMap<>();
+    private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets =
+            new HashMap<>();
+    private boolean allInputsReceived = false;
+    private boolean allOutputsReceived = false;
+
+    public ChannelStatePendingResult(
+            int subtaskIndex,
+            long checkpointId,
+            ChannelStateWriter.ChannelStateWriteResult result,
+            ChannelStateSerializer serializer) {
+        checkArgument(subtaskIndex >= 0, "Illegal subtask index %s.", subtaskIndex);
+        this.subtaskIndex = subtaskIndex;
+        this.checkpointId = checkpointId;
+        this.result = result;
+        this.serializer = serializer;
+    }
+
+    /** Returns whether {@link #completeInput()} has been called. */
+    public boolean isAllInputsReceived() {
+        return allInputsReceived;
+    }
+
+    /** Returns whether {@link #completeOutput()} has been called. */
+    public boolean isAllOutputsReceived() {
+        return allOutputsReceived;
+    }
+
+    /** Returns the live, mutable map of input channel offsets. */
+    public Map<InputChannelInfo, StateContentMetaInfo> getInputChannelOffsets() {
+        return inputChannelOffsets;
+    }
+
+    /** Returns the live, mutable map of result subpartition offsets. */
+    public Map<ResultSubpartitionInfo, StateContentMetaInfo> getResultSubpartitionOffsets() {
+        return resultSubpartitionOffsets;
+    }
+
+    /**
+     * Marks all inputs of this subtask as received.
+     *
+     * @throws IllegalStateException if the inputs were already completed
+     */
+    void completeInput() {
+        // Checkpoint id and subtask index are logged because several subtasks may share a writer.
+        LOG.debug(
+                "complete input, checkpointId: {}, subtaskIndex: {}, output completed: {}",
+                checkpointId,
+                subtaskIndex,
+                allOutputsReceived);
+        checkState(
+                !allInputsReceived,
+                "The input of checkpoint %s of subtask %s has already been completed.",
+                checkpointId,
+                subtaskIndex);
+        allInputsReceived = true;
+    }
+
+    /**
+     * Marks all outputs of this subtask as received.
+     *
+     * @throws IllegalStateException if the outputs were already completed
+     */
+    void completeOutput() {
+        LOG.debug(
+                "complete output, checkpointId: {}, subtaskIndex: {}, input completed: {}",
+                checkpointId,
+                subtaskIndex,
+                allInputsReceived);
+        checkState(
+                !allOutputsReceived,
+                "The output of checkpoint %s of subtask %s has already been completed.",
+                checkpointId,
+                subtaskIndex);
+        allOutputsReceived = true;
+    }
+
+    /**
+     * Completes the input channel and result subpartition futures of the result with one handle
+     * per recorded channel, built from {@code stateHandle}.
+     *
+     * @param stateHandle the handle of the written channel state; may only be {@code null} if no
+     *     data was written for this subtask
+     * @throws IOException if creating a handle fails
+     */
+    public void finishResult(@Nullable StreamStateHandle stateHandle) throws IOException {
+        checkState(
+                stateHandle != null
+                        || (inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty()),
+                "The stateHandle just can be null when no data is written.");
+        complete(
+                stateHandle,
+                result.inputChannelStateHandles,
+                inputChannelOffsets,
+                HandleFactory.INPUT_CHANNEL);
+        complete(
+                stateHandle,
+                result.resultSubpartitionStateHandles,
+                resultSubpartitionOffsets,
+                HandleFactory.RESULT_SUBPARTITION);
+    }
+
+    /** Builds one handle per entry of {@code offsets} and completes {@code future} with them. */
+    private <I, H extends AbstractChannelStateHandle<I>> void complete(
+            StreamStateHandle underlying,
+            CompletableFuture<Collection<H>> future,
+            Map<I, StateContentMetaInfo> offsets,
+            HandleFactory<I, H> handleFactory)
+            throws IOException {
+        final Collection<H> handles = new ArrayList<>();
+        for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
+            handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
+        }
+        future.complete(handles);
+        LOG.debug(
+                "channel state write completed, checkpointId: {}, subtaskIndex: {}, handles: {}",
+                checkpointId,
+                subtaskIndex,
+                handles);
+    }
+
+    /** Creates the handle of a single channel from the shared {@code underlying} handle. */
+    private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
+            HandleFactory<I, H> handleFactory,
+            StreamStateHandle underlying,
+            I channelInfo,
+            StateContentMetaInfo contentMetaInfo)
+            throws IOException {
+        // todo: consider restructuring channel state and removing asBytesIfInMemory:
+        // https://issues.apache.org/jira/browse/FLINK-17972
+        Optional<byte[]> bytes = underlying.asBytesIfInMemory();
+        if (bytes.isPresent()) {
+            // In-memory state: copy only this channel's data into its own handle.
+            // NOTE(review): presumably extractAndMerge puts a single header in front of the
+            // extracted data, hence the single offset passed below — confirm.
+            StreamStateHandle extracted =
+                    new ByteStreamStateHandle(
+                            randomUUID().toString(),
+                            serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
+            return handleFactory.create(
+                    subtaskIndex,
+                    channelInfo,
+                    extracted,
+                    singletonList(serializer.getHeaderLength()),
+                    extracted.getStateSize());
+        }
+        // Otherwise the handle refers to this channel's offsets within the underlying state.
+        return handleFactory.create(
+                subtaskIndex,
+                channelInfo,
+                underlying,
+                contentMetaInfo.getOffsets(),
+                contentMetaInfo.getSize());
+    }
+
+    /** Fails the subtask's result with {@code e}. */
+    public void fail(Throwable e) {
+        result.fail(e);
+    }
+
+    /** Returns whether the underlying {@link ChannelStateWriter.ChannelStateWriteResult} is done. */
+    public boolean isDone() {
+        return this.result.isDone();
+    }
+
+    /** Creates a channel state handle; matches the constructors of both handle types. */
+    private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
+        H create(
+                int subtaskIndex,
+                I info,
+                StreamStateHandle underlying,
+                List<Long> offsets,
+                long size);
+
+        HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL =
+                InputChannelStateHandle::new;
+
+        HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION =
+                ResultSubpartitionStateHandle::new;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
index 3fcd8975c61..abef241c325 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
@@ -18,7 +18,9 @@
 package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
+import org.apache.flink.runtime.io.AvailabilityProvider;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.Preconditions;
@@ -27,6 +29,8 @@ import org.apache.flink.util.function.ThrowingConsumer;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nullable;
+
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
@@ -41,69 +45,151 @@ import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRe
 import static org.apache.flink.util.CloseableIterator.ofElements;
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+abstract class ChannelStateWriteRequest {
+
+    private static final Logger LOG = LoggerFactory.getLogger(ChannelStateWriteRequest.class);
+
+    private final JobVertexID jobVertexID;
+
+    private final int subtaskIndex;
 
-interface ChannelStateWriteRequest {
+    private final long checkpointId;
+
+    private final String name;
 
-    Logger LOG = LoggerFactory.getLogger(ChannelStateWriteRequest.class);
+    /**
+     * Creates a request of the given subtask and checkpoint.
+     *
+     * @param jobVertexID the job vertex of the subtask this request belongs to
+     * @param subtaskIndex the index of the subtask this request belongs to
+     * @param checkpointId the checkpoint this request belongs to ({@code 0} for requests that are
+     *     not bound to a checkpoint, such as register/release requests)
+     * @param name the request name used by {@link #toString()}; must not be {@code null}
+     */
+    public ChannelStateWriteRequest(
+            JobVertexID jobVertexID, int subtaskIndex, long checkpointId, String name) {
+        this.jobVertexID = jobVertexID;
+        this.subtaskIndex = subtaskIndex;
+        this.checkpointId = checkpointId;
+        // Keep the null check that CheckpointInProgressRequest applied to its name.
+        this.name = checkNotNull(name);
+    }
+
+    /** Returns the job vertex of the subtask this request belongs to. */
+    public final JobVertexID getJobVertexID() {
+        return this.jobVertexID;
+    }
+
+    /** Returns the index of the subtask this request belongs to. */
+    public final int getSubtaskIndex() {
+        return this.subtaskIndex;
+    }
+
+    /** Returns the checkpoint of this request ({@code 0} for register/release requests). */
+    public final long getCheckpointId() {
+        return this.checkpointId;
+    }
+
+    /**
+     * Returns a future that signals whether this request is ready to be executed.
+     *
+     * <p>By default a request is ready immediately. Requests that write a channel state data
+     * future override this with that data future, which may not be complete yet. {@link
+     * ChannelStateWriteRequestExecutorImpl} executes ready requests first to avoid a deadlock.
+     */
+    public CompletableFuture<?> getReadyFuture() {
+        return AvailabilityProvider.AVAILABLE;
+    }
 
-    long getCheckpointId();
+    /** Renders the request name followed by its subtask and checkpoint coordinates. */
+    @Override
+    public String toString() {
+        return String.format(
+                "%s {jobVertexID=%s, subtaskIndex=%s, checkpointId=%s}",
+                name, jobVertexID, subtaskIndex, checkpointId);
+    }
 
-    void cancel(Throwable cause) throws Exception;
+    /**
+     * Cancels this request with the given cause instead of executing it, e.g. because its
+     * checkpoint was aborted or subsumed.
+     */
+    abstract void cancel(Throwable cause) throws Exception;
 
+    /**
+     * Creates a request that calls {@code completeInput} on the checkpoint writer for the given
+     * subtask, i.e. signals that its input channel state of {@code checkpointId} is complete.
+     */
-    static CheckpointInProgressRequest completeInput(long checkpointId) {
+    static CheckpointInProgressRequest completeInput(
+            JobVertexID jobVertexID, int subtaskIndex, long checkpointId) {
         return new CheckpointInProgressRequest(
-                "completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput);
+                "completeInput",
+                jobVertexID,
+                subtaskIndex,
+                checkpointId,
+                writer -> writer.completeInput(jobVertexID, subtaskIndex));
     }
 
+    /**
+     * Creates a request that calls {@code completeOutput} on the checkpoint writer for the given
+     * subtask, i.e. signals that its output channel state of {@code checkpointId} is complete.
+     */
-    static CheckpointInProgressRequest completeOutput(long checkpointId) {
+    static CheckpointInProgressRequest completeOutput(
+            JobVertexID jobVertexID, int subtaskIndex, long checkpointId) {
         return new CheckpointInProgressRequest(
-                "completeOutput", checkpointId, ChannelStateCheckpointWriter::completeOutput);
+                "completeOutput",
+                jobVertexID,
+                subtaskIndex,
+                checkpointId,
+                writer -> writer.completeOutput(jobVertexID, subtaskIndex));
     }
 
+    /**
+     * Creates a request that writes the buffers of {@code iterator} as the state of input channel
+     * {@code info} of the given subtask.
+     */
     static ChannelStateWriteRequest write(
-            long checkpointId, InputChannelInfo info, CloseableIterator<Buffer> iterator) {
+            JobVertexID jobVertexID,
+            int subtaskIndex,
+            long checkpointId,
+            InputChannelInfo info,
+            CloseableIterator<Buffer> iterator) {
         return buildWriteRequest(
+                jobVertexID,
+                subtaskIndex,
                 checkpointId,
                 "writeInput",
                 iterator,
-                (writer, buffer) -> writer.writeInput(info, buffer));
+                (writer, buffer) -> writer.writeInput(jobVertexID, subtaskIndex, info, buffer));
     }
 
+    /**
+     * Creates a request that writes {@code buffers} as the state of result subpartition {@code
+     * info} of the given subtask. NOTE(review): the buffers are wrapped with {@code
+     * Buffer::recycleBuffer} as close action, presumably recycling them if not consumed — confirm.
+     */
     static ChannelStateWriteRequest write(
-            long checkpointId, ResultSubpartitionInfo info, Buffer... buffers) {
+            JobVertexID jobVertexID,
+            int subtaskIndex,
+            long checkpointId,
+            ResultSubpartitionInfo info,
+            Buffer... buffers) {
         return buildWriteRequest(
+                jobVertexID,
+                subtaskIndex,
                 checkpointId,
                 "writeOutput",
                 ofElements(Buffer::recycleBuffer, buffers),
-                (writer, buffer) -> writer.writeOutput(info, buffer));
+                (writer, buffer) -> writer.writeOutput(jobVertexID, subtaskIndex, info, buffer));
     }
 
+    /**
+     * Creates a request that writes the buffers of {@code dataFuture} as the state of result
+     * subpartition {@code info} of the given subtask. {@code dataFuture} also becomes the ready
+     * future of the request, so it is only considered ready once the data is available.
+     */
     static ChannelStateWriteRequest write(
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             ResultSubpartitionInfo info,
             CompletableFuture<List<Buffer>> dataFuture) {
         return buildFutureWriteRequest(
+                jobVertexID,
+                subtaskIndex,
                 checkpointId,
                 "writeOutputFuture",
                 dataFuture,
-                (writer, buffer) -> writer.writeOutput(info, buffer));
+                (writer, buffer) -> writer.writeOutput(jobVertexID, subtaskIndex, info, buffer));
     }
 
     static ChannelStateWriteRequest buildFutureWriteRequest(
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             String name,
             CompletableFuture<List<Buffer>> dataFuture,
             BiConsumer<ChannelStateCheckpointWriter, Buffer> bufferConsumer) {
         return new CheckpointInProgressRequest(
                 name,
+                jobVertexID,
+                subtaskIndex,
                 checkpointId,
                 writer -> {
+                    checkState(
+                            dataFuture.isDone(), "It should be executed when dataFuture is done.");
                     List<Buffer> buffers;
                     try {
                         buffers = dataFuture.get();
                     } catch (ExecutionException e) {
                         // If dataFuture fails, fail only the single related writer
-                        writer.fail(e);
+                        writer.fail(jobVertexID, subtaskIndex, e);
                         return;
                     }
                     for (Buffer buffer : buffers) {
@@ -122,16 +208,21 @@ interface ChannelStateWriteRequest {
                                                 "Failed to recycle the output buffer of channel state.",
                                                 e);
                                     }
-                                }));
+                                }),
+                dataFuture);
     }
 
     static ChannelStateWriteRequest buildWriteRequest(
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             String name,
             CloseableIterator<Buffer> iterator,
             BiConsumer<ChannelStateCheckpointWriter, Buffer> bufferConsumer) {
         return new CheckpointInProgressRequest(
                 name,
+                jobVertexID,
+                subtaskIndex,
                 checkpointId,
                 writer -> {
                     while (iterator.hasNext()) {
@@ -153,14 +244,26 @@ interface ChannelStateWriteRequest {
     }
 
+    /**
+     * Creates a request that starts checkpoint {@code checkpointId} for the given subtask; the
+     * subtask's channel state handles are reported through {@code targetResult}.
+     */
     static ChannelStateWriteRequest start(
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             ChannelStateWriteResult targetResult,
             CheckpointStorageLocationReference locationReference) {
-        return new CheckpointStartRequest(checkpointId, targetResult, locationReference);
+        return new CheckpointStartRequest(
+                jobVertexID, subtaskIndex, checkpointId, targetResult, locationReference);
+    }
+
+    /** Creates a request that aborts checkpoint {@code checkpointId} on behalf of the subtask. */
+    static ChannelStateWriteRequest abort(
+            JobVertexID jobVertexID, int subtaskIndex, long checkpointId, Throwable cause) {
+        return new CheckpointAbortRequest(jobVertexID, subtaskIndex, checkpointId, cause);
     }
 
-    static ChannelStateWriteRequest abort(long checkpointId, Throwable cause) {
-        return new CheckpointAbortRequest(checkpointId, cause);
+    /** Creates a request that registers the given subtask; it is not bound to a checkpoint. */
+    static ChannelStateWriteRequest registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+        return new SubtaskRegisterRequest(jobVertexID, subtaskIndex);
+    }
+
+    /** Creates a request that releases the given subtask; it is not bound to a checkpoint. */
+    static ChannelStateWriteRequest releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+        return new SubtaskReleaseRequest(jobVertexID, subtaskIndex);
     }
 
     static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) {
@@ -172,25 +275,22 @@ interface ChannelStateWriteRequest {
     }
 }
 
-final class CheckpointStartRequest implements ChannelStateWriteRequest {
+final class CheckpointStartRequest extends ChannelStateWriteRequest {
+
     private final ChannelStateWriteResult targetResult;
     private final CheckpointStorageLocationReference locationReference;
-    private final long checkpointId;
 
+    /**
+     * Creates a start request of checkpoint {@code checkpointId} for the given subtask.
+     *
+     * @param targetResult the result the subtask's channel state handles are reported to
+     * @param locationReference used to resolve the checkpoint storage location
+     */
     CheckpointStartRequest(
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             ChannelStateWriteResult targetResult,
             CheckpointStorageLocationReference locationReference) {
-        this.checkpointId = checkpointId;
+        super(jobVertexID, subtaskIndex, checkpointId, "Start");
         this.targetResult = checkNotNull(targetResult);
         this.locationReference = checkNotNull(locationReference);
     }
 
-    @Override
-    public long getCheckpointId() {
-        return checkpointId;
-    }
-
+    /** Returns the result of the requesting subtask, to be registered with the writer. */
     ChannelStateWriteResult getTargetResult() {
         return targetResult;
     }
@@ -203,11 +303,6 @@ final class CheckpointStartRequest implements ChannelStateWriteRequest {
     public void cancel(Throwable cause) {
+        // A cancelled start request fails the target result of the subtask with the given cause.
         targetResult.fail(cause);
     }
-
-    @Override
-    public String toString() {
-        return "start " + checkpointId;
-    }
 }
 
 enum CheckpointInProgressRequestState {
@@ -218,35 +313,44 @@ enum CheckpointInProgressRequestState {
     CANCELLED
 }
 
-final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
+final class CheckpointInProgressRequest extends ChannelStateWriteRequest {
     private final ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action;
     private final ThrowingConsumer<Throwable, Exception> discardAction;
-    private final long checkpointId;
-    private final String name;
     private final AtomicReference<CheckpointInProgressRequestState> state =
             new AtomicReference<>(NEW);
+    @Nullable private final CompletableFuture<?> readyFuture;
 
+    /** Creates a request with a no-op discard action that is ready immediately. */
     CheckpointInProgressRequest(
             String name,
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action) {
-        this(name, checkpointId, action, unused -> {});
+        this(name, jobVertexID, subtaskIndex, checkpointId, action, unused -> {});
     }
 
+    /** Creates a request that is ready immediately, i.e. that has no ready future. */
     CheckpointInProgressRequest(
             String name,
+            JobVertexID jobVertexID,
+            int subtaskIndex,
             long checkpointId,
             ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
             ThrowingConsumer<Throwable, Exception> discardAction) {
-        this.checkpointId = checkpointId;
-        this.action = checkNotNull(action);
-        this.discardAction = checkNotNull(discardAction);
-        this.name = checkNotNull(name);
+        this(name, jobVertexID, subtaskIndex, checkpointId, action, discardAction, null);
     }
 
-    @Override
-    public long getCheckpointId() {
-        return checkpointId;
+    /**
+     * Creates a request executing {@code action} against the checkpoint writer.
+     *
+     * @param discardAction invoked with the cause when the request is discarded
+     * @param readyFuture returned by {@link #getReadyFuture()}; {@code null} means the request is
+     *     ready immediately
+     */
+    CheckpointInProgressRequest(
+            String name,
+            JobVertexID jobVertexID,
+            int subtaskIndex,
+            long checkpointId,
+            ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
+            ThrowingConsumer<Throwable, Exception> discardAction,
+            @Nullable CompletableFuture<?> readyFuture) {
+        super(jobVertexID, subtaskIndex, checkpointId, name);
+        this.action = checkNotNull(action);
+        this.discardAction = checkNotNull(discardAction);
+        this.readyFuture = readyFuture;
     }
 
     @Override
@@ -268,19 +372,21 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
     }
 
     @Override
-    public String toString() {
-        return name + " " + checkpointId;
+    public CompletableFuture<?> getReadyFuture() {
+        // Without a data future, fall back to the base class, which is always ready.
+        return readyFuture == null ? super.getReadyFuture() : readyFuture;
     }
 }
 
-final class CheckpointAbortRequest implements ChannelStateWriteRequest {
-
-    private final long checkpointId;
+final class CheckpointAbortRequest extends ChannelStateWriteRequest {
 
     private final Throwable throwable;
 
-    public CheckpointAbortRequest(long checkpointId, Throwable throwable) {
-        this.checkpointId = checkpointId;
+    /** Creates a request aborting {@code checkpointId} on behalf of the subtask with a cause. */
+    public CheckpointAbortRequest(
+            JobVertexID jobVertexID, int subtaskIndex, long checkpointId, Throwable throwable) {
+        super(jobVertexID, subtaskIndex, checkpointId, "Abort");
         this.throwable = throwable;
     }
 
@@ -289,15 +395,30 @@ final class CheckpointAbortRequest implements ChannelStateWriteRequest {
     }
 
+    // An abort request only carries its cause, so cancelling it is a no-op.
     @Override
-    public long getCheckpointId() {
-        return checkpointId;
+    public void cancel(Throwable cause) throws Exception {}
+
+    /** Appends the abort cause to the common request description. */
+    @Override
+    public String toString() {
+        return super.toString() + ", cause : " + throwable + ".";
+    }
+}
+
+/**
+ * Registers a subtask, so that channel state writers built afterwards take it into account.
+ *
+ * <p>It is not bound to a checkpoint; the checkpoint id {@code 0} is only a placeholder.
+ */
+final class SubtaskRegisterRequest extends ChannelStateWriteRequest {
+
+    public SubtaskRegisterRequest(JobVertexID jobVertexID, int subtaskIndex) {
+        super(jobVertexID, subtaskIndex, 0, "Register");
     }
 
+    // Nothing to clean up: the request holds no resources.
     @Override
     public void cancel(Throwable cause) throws Exception {}
+}
 
-    @Override
-    public String toString() {
-        return "Abort checkpointId-" + checkpointId;
+/**
+ * Releases a subtask, removing it from the registered subtasks and from the current channel state
+ * writer, if any.
+ *
+ * <p>It is not bound to a checkpoint; the checkpoint id {@code 0} is only a placeholder.
+ */
+final class SubtaskReleaseRequest extends ChannelStateWriteRequest {
+
+    public SubtaskReleaseRequest(JobVertexID jobVertexID, int subtaskIndex) {
+        super(jobVertexID, subtaskIndex, 0, "Release");
     }
+
+    // Nothing to clean up: the request holds no resources.
+    @Override
+    public void cancel(Throwable cause) throws Exception {}
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
index fb2e4bedc79..5151d9be701 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
@@ -17,12 +17,20 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
 import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED;
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -36,10 +44,15 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
     private static final Logger LOG =
             LoggerFactory.getLogger(ChannelStateWriteRequestDispatcherImpl.class);
 
-    private final CheckpointStorageWorkerView streamFactoryResolver;
+    private final CheckpointStorage checkpointStorage;
+
+    private final JobID jobID;
+
     private final ChannelStateSerializer serializer;
-    private final int subtaskIndex;
-    private final String taskName;
+
+    private final Set<SubtaskID> registeredSubtasks;
+
+    private CheckpointStorageWorkerView streamFactoryResolver;
 
     /**
      * It is the checkpointId corresponding to writer. And It should be always update with {@link
@@ -53,6 +66,9 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
      */
     private long maxAbortedCheckpointId;
 
+    /** The aborted subtask of the maxAbortedCheckpointId. */
+    private SubtaskID abortedSubtaskID;
+
     /** The aborted cause of the maxAbortedCheckpointId. */
     private Throwable abortedCause;
 
@@ -62,14 +78,11 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
     private ChannelStateCheckpointWriter writer;
 
+    /**
+     * Creates a dispatcher for the channel state of the subtasks of job {@code jobID}.
+     *
+     * @param checkpointStorage used to lazily create the checkpoint storage worker view
+     * @param jobID the job the checkpoint storage worker view is created for
+     * @param serializer passed to the channel state checkpoint writers
+     */
     ChannelStateWriteRequestDispatcherImpl(
-            String taskName,
-            int subtaskIndex,
-            CheckpointStorageWorkerView streamFactoryResolver,
-            ChannelStateSerializer serializer) {
-        this.taskName = taskName;
-        this.subtaskIndex = subtaskIndex;
-        this.streamFactoryResolver = checkNotNull(streamFactoryResolver);
+            CheckpointStorage checkpointStorage, JobID jobID, ChannelStateSerializer serializer) {
+        this.checkpointStorage = checkNotNull(checkpointStorage);
+        // Only read by getStreamFactoryResolver(), which creates the storage view on first use.
+        this.jobID = jobID;
         this.serializer = checkNotNull(serializer);
+        this.registeredSubtasks = new HashSet<>();
+        // -1: no checkpoint has been started or aborted yet.
         this.ongoingCheckpointId = -1;
         this.maxAbortedCheckpointId = -1;
     }
@@ -90,56 +103,95 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
     }
 
+    /** Routes a request to the handler of its type. */
     private void dispatchInternal(ChannelStateWriteRequest request) throws Exception {
-        if (isAbortedCheckpoint(request.getCheckpointId())) {
-            if (request.getCheckpointId() == maxAbortedCheckpointId) {
-                request.cancel(abortedCause);
-            } else {
-                request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+        // Register/release requests are not bound to a checkpoint, so they are handled before the
+        // aborted-checkpoint check below.
+        if (request instanceof SubtaskRegisterRequest) {
+            SubtaskRegisterRequest req = (SubtaskRegisterRequest) request;
+            SubtaskID subtaskID = SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex());
+            registeredSubtasks.add(subtaskID);
+            return;
+        } else if (request instanceof SubtaskReleaseRequest) {
+            SubtaskReleaseRequest req = (SubtaskReleaseRequest) request;
+            SubtaskID subtaskID = SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex());
+            registeredSubtasks.remove(subtaskID);
+            if (writer == null) {
+                return;
             }
+            // The current writer must also stop taking the released subtask into account.
+            writer.releaseSubtask(subtaskID);
             return;
         }
 
-        if (request instanceof CheckpointStartRequest) {
-            checkState(
-                    request.getCheckpointId() > ongoingCheckpointId,
-                    String.format(
-                            "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.",
-                            ongoingCheckpointId, request));
-            failAndClearWriter(
-                    new IllegalStateException(
-                            String.format(
-                                    "Task[name=%s, subtaskIndex=%s] has uncompleted channelState writer of checkpointId=%s, "
-                                            + "but it received a new checkpoint start request of checkpointId=%s, it maybe "
-                                            + "a bug due to currently not supported concurrent unaligned checkpoint.",
-                                    taskName,
-                                    subtaskIndex,
-                                    ongoingCheckpointId,
-                                    request.getCheckpointId())));
-            this.writer = buildWriter((CheckpointStartRequest) request);
-            this.ongoingCheckpointId = request.getCheckpointId();
+        // Requests of aborted or subsumed checkpoints are cancelled instead of being executed.
+        if (isAbortedCheckpoint(request.getCheckpointId())) {
+            handleAbortedRequest(request);
+        } else if (request instanceof CheckpointStartRequest) {
+            handleCheckpointStartRequest(request);
         } else if (request instanceof CheckpointInProgressRequest) {
-            CheckpointInProgressRequest req = (CheckpointInProgressRequest) request;
-            checkArgument(
-                    ongoingCheckpointId == req.getCheckpointId() && writer != null,
-                    "writer not found while processing request: " + req);
-            req.execute(writer);
+            handleCheckpointInProgressRequest((CheckpointInProgressRequest) request);
         } else if (request instanceof CheckpointAbortRequest) {
-            CheckpointAbortRequest req = (CheckpointAbortRequest) request;
-            if (request.getCheckpointId() > maxAbortedCheckpointId) {
-                this.maxAbortedCheckpointId = req.getCheckpointId();
-                this.abortedCause = req.getThrowable();
-            }
-
-            if (req.getCheckpointId() == ongoingCheckpointId) {
-                failAndClearWriter(req.getThrowable());
-            } else if (request.getCheckpointId() > ongoingCheckpointId) {
-                failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
-            }
+            handleCheckpointAbortRequest(request);
         } else {
             throw new IllegalArgumentException("unknown request type: " + request);
         }
     }
 
+    /**
+     * Cancels a request whose checkpoint was aborted or subsumed.
+     *
+     * <p>Requests of the latest aborted checkpoint ({@code maxAbortedCheckpointId}) are cancelled
+     * with the abort cause if they come from the aborting subtask. Requests from any other subtask
+     * get a {@code CHANNEL_STATE_SHARED_STREAM_EXCEPTION} wrapping that cause. Requests of any
+     * other checkpoint are cancelled as subsumed.
+     */
+    private void handleAbortedRequest(ChannelStateWriteRequest request) throws Exception {
+        final Throwable cancelCause;
+        if (request.getCheckpointId() != maxAbortedCheckpointId) {
+            cancelCause = new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED);
+        } else if (SubtaskID.of(request.getJobVertexID(), request.getSubtaskIndex())
+                .equals(abortedSubtaskID)) {
+            cancelCause = abortedCause;
+        } else {
+            cancelCause =
+                    new CheckpointException(CHANNEL_STATE_SHARED_STREAM_EXCEPTION, abortedCause);
+        }
+        request.cancel(cancelCause);
+    }
+
+    /**
+     * Registers the target result of a start request with the writer of its checkpoint. The
+     * writer is built first if no writer exists yet for that checkpoint.
+     */
+    private void handleCheckpointStartRequest(ChannelStateWriteRequest request) throws Exception {
+        // Template and arguments are passed separately, so the message (including
+        // request.toString()) is only built if the check fails.
+        checkState(
+                request.getCheckpointId() >= ongoingCheckpointId,
+                "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.",
+                ongoingCheckpointId,
+                request);
+        if (request.getCheckpointId() > ongoingCheckpointId) {
+            // A newer checkpoint started: the writer of the previous checkpoint is subsumed.
+            failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+        }
+        CheckpointStartRequest req = (CheckpointStartRequest) request;
+        // The writer may already exist, because another subtask that shares the channel state
+        // file may have built it for ongoingCheckpointId.
+        if (writer == null) {
+            this.writer = buildWriter(req);
+            this.ongoingCheckpointId = request.getCheckpointId();
+        }
+        writer.registerSubtaskResult(
+                SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex()), req.getTargetResult());
+    }
+
+    /** Executes an in-progress request against the writer of the ongoing checkpoint. */
+    private void handleCheckpointInProgressRequest(CheckpointInProgressRequest req)
+            throws Exception {
+        // Runs for every data write request: the template form avoids building the message
+        // (and calling req.toString()) when the check passes.
+        checkArgument(
+                ongoingCheckpointId == req.getCheckpointId() && writer != null,
+                "writer not found while processing request: %s",
+                req);
+        req.execute(writer);
+    }
+
+    /**
+     * Records the abort of a checkpoint and fails the affected writer.
+     *
+     * <p>If the aborted checkpoint is the ongoing one, the writer is failed with the abort cause on
+     * behalf of the aborting subtask. If it is newer, the writer of the ongoing checkpoint is
+     * failed as subsumed.
+     */
+    private void handleCheckpointAbortRequest(ChannelStateWriteRequest request) {
+        final CheckpointAbortRequest abortRequest = (CheckpointAbortRequest) request;
+        final long abortedId = abortRequest.getCheckpointId();
+        if (abortedId > maxAbortedCheckpointId) {
+            maxAbortedCheckpointId = abortedId;
+            abortedCause = abortRequest.getThrowable();
+            abortedSubtaskID =
+                    SubtaskID.of(abortRequest.getJobVertexID(), abortRequest.getSubtaskIndex());
+        }
+
+        if (abortedId == ongoingCheckpointId) {
+            failAndClearWriter(
+                    abortRequest.getJobVertexID(),
+                    abortRequest.getSubtaskIndex(),
+                    abortRequest.getThrowable());
+        } else if (abortedId > ongoingCheckpointId) {
+            failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+        }
+    }
+
+    /**
+     * Returns whether requests of the given checkpoint must be cancelled. This is the case if the
+     * checkpoint is older than the ongoing one, or not newer than the latest aborted checkpoint.
+     */
     private boolean isAbortedCheckpoint(long checkpointId) {
         return checkpointId < ongoingCheckpointId || checkpointId <= maxAbortedCheckpointId;
     }
@@ -152,14 +204,23 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
         writer = null;
     }
 
+    /**
+     * Fails the current writer, if there is one, on behalf of the given subtask and drops it; the
+     * next checkpoint start request then builds a fresh writer.
+     */
+    private void failAndClearWriter(
+            JobVertexID jobVertexID, int subtaskIndex, Throwable throwable) {
+        if (writer != null) {
+            writer.fail(jobVertexID, subtaskIndex, throwable);
+            writer = null;
+        }
+    }
+
     private ChannelStateCheckpointWriter buildWriter(CheckpointStartRequest request)
             throws Exception {
         return new ChannelStateCheckpointWriter(
-                taskName,
-                subtaskIndex,
-                request,
-                streamFactoryResolver.resolveCheckpointStorageLocation(
-                        request.getCheckpointId(), request.getLocationReference()),
+                registeredSubtasks,
+                request.getCheckpointId(),
+                getStreamFactoryResolver()
+                        .resolveCheckpointStorageLocation(
+                                request.getCheckpointId(), request.getLocationReference()),
                 serializer,
                 () -> {
                     checkState(
@@ -183,4 +244,11 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
         }
         writer = null;
     }
+
+    /**
+     * Returns the checkpoint storage worker view of this job. It is created from {@code
+     * checkpointStorage} on first use and reused afterwards.
+     */
+    CheckpointStorageWorkerView getStreamFactoryResolver() throws IOException {
+        if (streamFactoryResolver != null) {
+            return streamFactoryResolver;
+        }
+        streamFactoryResolver = checkpointStorage.createCheckpointStorage(jobID);
+        return streamFactoryResolver;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java
index 52256873f93..42d0142f112 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java
@@ -17,15 +17,17 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
-import java.io.Closeable;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
+import java.io.IOException;
 
 /**
  * Executes {@link ChannelStateWriteRequest}s potentially asynchronously. An exception thrown during
  * the execution should be re-thrown on any next call.
  */
-interface ChannelStateWriteRequestExecutor extends Closeable {
+interface ChannelStateWriteRequestExecutor {
 
-    /** @throws IllegalStateException if called more than once or after {@link #close()} */
+    /** @throws IllegalStateException if called more than once or after {@link #releaseSubtask} */
     void start() throws IllegalStateException;
 
     /**
@@ -45,4 +47,10 @@ interface ChannelStateWriteRequestExecutor extends Closeable {
      * @throws Exception if any exception occurred during processing this or other items previously
      */
     void submitPriority(ChannelStateWriteRequest r) throws Exception;
+
+    /**
+     * Registers a subtask as a user of this executor. Requests may only be submitted for
+     * registered subtasks.
+     */
+    void registerSubtask(JobVertexID jobVertexID, int subtaskIndex);
+
+    /**
+     * Releases a previously registered subtask. Once the last registered subtask is released, the
+     * executor stops its writer thread.
+     *
+     * @throws IOException if the executor failed while processing requests
+     */
+    void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) throws IOException;
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java
new file mode 100644
index 00000000000..c41e96a24f5
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.CheckpointStorage;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Creates {@link ChannelStateWriteRequestExecutor}s for the subtasks of one job and lets subtasks
+ * share them.
+ *
+ * <p>The executor that is currently accepting registrations is handed out to every subtask that
+ * asks for one. Once its registration completes, it is detached from this factory and the next
+ * request creates and starts a new executor.
+ */
+public class ChannelStateWriteRequestExecutorFactory {
+
+    private final JobID jobID;
+
+    /** Guards {@link #executor}; also taken by the registration callback of the executor. */
+    private final Object lock = new Object();
+
+    /** The executor currently accepting registrations, or {@code null} if there is none. */
+    @GuardedBy("lock")
+    private ChannelStateWriteRequestExecutor executor;
+
+    public ChannelStateWriteRequestExecutorFactory(JobID jobID) {
+        this.jobID = jobID;
+    }
+
+    /**
+     * Registers the given subtask with the executor that is currently accepting registrations,
+     * creating and starting a new executor first if there is none.
+     *
+     * @param jobVertexID vertex of the registering subtask
+     * @param subtaskIndex index of the registering subtask
+     * @param checkpointStorage storage for a newly created executor; ignored if one is reused
+     * @param maxSubtasksPerChannelStateFile number of registrations after which a newly created
+     *     executor stops accepting subtasks; ignored if one is reused
+     * @return the executor the subtask has been registered with
+     */
+    public ChannelStateWriteRequestExecutor getOrCreateExecutor(
+            JobVertexID jobVertexID,
+            int subtaskIndex,
+            CheckpointStorage checkpointStorage,
+            int maxSubtasksPerChannelStateFile) {
+        synchronized (lock) {
+            if (executor == null) {
+                executor =
+                        new ChannelStateWriteRequestExecutorImpl(
+                                new ChannelStateWriteRequestDispatcherImpl(
+                                        checkpointStorage, jobID, new ChannelStateSerializerImpl()),
+                                maxSubtasksPerChannelStateFile,
+                                // Called once the executor stops accepting registrations.
+                                registeredExecutor -> {
+                                    synchronized (lock) {
+                                        checkState(this.executor == registeredExecutor);
+                                        this.executor = null;
+                                    }
+                                });
+                executor.start();
+            }
+            // registerSubtask may complete the registration and run the callback above on this
+            // thread (the lock is reentrant), which clears the field; keep a local reference.
+            // NOTE(review): the executor's writer thread can also complete the registration (on a
+            // checkpoint start) and only then block on this lock in the callback. A registration
+            // made in that window fails the executor's isRegistering() check - please confirm
+            // whether the factory and the executor should share one lock.
+            ChannelStateWriteRequestExecutor currentExecutor = executor;
+            currentExecutor.registerSubtask(jobVertexID, subtaskIndex);
+            return currentExecutor;
+        }
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
index 5c12a778d64..bafc5b69c8c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
@@ -18,23 +18,37 @@
 package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.fs.FileSystemSafetyNet;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.function.RunnableWithException;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
 import javax.annotation.concurrent.ThreadSafe;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
 import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
-import java.util.concurrent.BlockingDeque;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.Set;
 import java.util.concurrent.CancellationException;
-import java.util.concurrent.LinkedBlockingDeque;
+import java.util.function.Consumer;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.util.IOUtils.closeAll;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /**
  * Executes {@link ChannelStateWriteRequest}s in a separate thread. Any exception occurred during
@@ -46,32 +60,59 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
     private static final Logger LOG =
             LoggerFactory.getLogger(ChannelStateWriteRequestExecutorImpl.class);
 
+    /** Guards the state below; the writer thread also waits on it for new requests. */
+    private final Object lock = new Object();
+
     private final ChannelStateWriteRequestDispatcher dispatcher;
-    private final BlockingDeque<ChannelStateWriteRequest> deque;
     private final Thread thread;
-    private volatile Exception thrown = null;
-    private volatile boolean wasClosed = false;
-    private final String taskName;
+
+    /** Registration completes once this many subtasks have registered. */
+    private final int maxSubtasksPerChannelStateFile;
+
+    /** Requests ready to be dispatched; priority requests are added at the head. */
+    @GuardedBy("lock")
+    private final Deque<ChannelStateWriteRequest> deque;
+
+    /** Failure of the writer thread, if any; later failures are added to it as suppressed. */
+    @GuardedBy("lock")
+    private Exception thrown = null;
+
+    /** Set when the last registered subtask is released; makes the writer loop exit. */
+    @GuardedBy("lock")
+    private boolean wasClosed = false;
+
+    /** Per registered subtask, requests still waiting for their ready future, in submit order. */
+    @GuardedBy("lock")
+    private final Map<SubtaskID, Queue<ChannelStateWriteRequest>> unreadyQueues = new HashMap<>();
+
+    /** Whether new subtasks may still register; cleared once by {@link #completeRegister()}. */
+    @GuardedBy("lock")
+    private boolean isRegistering = true;
+
+    /** Subtasks that have registered and have not been released yet. */
+    @GuardedBy("lock")
+    private final Set<SubtaskID> subtasks;
+
+    /**
+     * Invoked once registration completes. Never invoke it while holding {@link #lock}: the
+     * factory's callback takes the factory lock, whose holder may be blocked on {@link #lock} in
+     * {@link #registerSubtask}, so doing so could deadlock.
+     */
+    private final Consumer<ChannelStateWriteRequestExecutor> onRegistered;
 
+    /**
+     * Creates an executor that keeps its ready requests in a new {@link ArrayDeque}.
+     *
+     * @param dispatcher dispatches each request on the writer thread
+     * @param maxSubtasksPerChannelStateFile number of registered subtasks after which registration
+     *     completes
+     * @param onRegistered invoked, outside of this executor's lock, once registration completes
+     */
     ChannelStateWriteRequestExecutorImpl(
-            String taskName, ChannelStateWriteRequestDispatcher dispatcher) {
-        this(taskName, dispatcher, new LinkedBlockingDeque<>());
+            ChannelStateWriteRequestDispatcher dispatcher,
+            int maxSubtasksPerChannelStateFile,
+            Consumer<ChannelStateWriteRequestExecutor> onRegistered) {
+        this(dispatcher, new ArrayDeque<>(), maxSubtasksPerChannelStateFile, onRegistered);
     }
 
     ChannelStateWriteRequestExecutorImpl(
-            String taskName,
             ChannelStateWriteRequestDispatcher dispatcher,
-            BlockingDeque<ChannelStateWriteRequest> deque) {
-        this.taskName = taskName;
+            Deque<ChannelStateWriteRequest> deque,
+            int maxSubtasksPerChannelStateFile,
+            Consumer<ChannelStateWriteRequestExecutor> onRegistered) {
         this.dispatcher = dispatcher;
         this.deque = deque;
-        this.thread = new Thread(this::run, "Channel state writer " + taskName);
+        this.maxSubtasksPerChannelStateFile = maxSubtasksPerChannelStateFile;
+        this.onRegistered = onRegistered;
+        this.thread = new Thread(this::run, "Channel state writer ");
+        this.subtasks = new HashSet<>();
         this.thread.setDaemon(true);
     }
 
     @VisibleForTesting
     void run() {
         try {
+            FileSystemSafetyNet.initializeSafetyNetForThread();
             loop();
         } catch (Exception ex) {
             thrown = ex;
@@ -79,39 +120,94 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
             try {
                 closeAll(
                         this::cleanupRequests,
-                        () ->
-                                dispatcher.fail(
-                                        thrown == null ? new CancellationException() : thrown));
+                        () -> {
+                            Throwable cause;
+                            synchronized (lock) {
+                                cause = thrown == null ? new CancellationException() : thrown;
+                            }
+                            dispatcher.fail(cause);
+                        });
             } catch (Exception e) {
-                //noinspection NonAtomicOperationOnVolatileField
-                thrown = ExceptionUtils.firstOrSuppressed(e, thrown);
+                synchronized (lock) {
+                    //noinspection NonAtomicOperationOnVolatileField
+                    thrown = ExceptionUtils.firstOrSuppressed(e, thrown);
+                }
             }
+            FileSystemSafetyNet.closeSafetyNetAndGuardedResourcesForThread();
         }
-        LOG.debug("{} loop terminated", taskName);
+        LOG.debug("loop terminated");
     }
 
+    /**
+     * Takes requests from {@link #deque} and dispatches them on this thread until the executor is
+     * closed; exceptions other than interrupts propagate to the caller. A {@link
+     * CheckpointStartRequest} taken while still registering completes the registration and
+     * triggers {@link #onRegistered}.
+     */
     private void loop() throws Exception {
-        while (!wasClosed) {
+        while (true) {
             try {
-                dispatcher.dispatch(deque.take());
+                ChannelStateWriteRequest request;
+                boolean completeRegister = false;
+                synchronized (lock) {
+                    request = waitAndTakeUnsafe();
+                    if (request == null) {
+                        // The executor is closed, so return directly.
+                        return;
+                    }
+                    // A checkpoint start ends the registration: checkpoints only start after all
+                    // tasks have been initialized, so no further subtask is expected to register.
+                    if (request instanceof CheckpointStartRequest) {
+                        completeRegister = completeRegister();
+                    }
+                }
+                // Notify outside the lock (see onRegistered) and before dispatching the request.
+                if (completeRegister) {
+                    onRegistered.accept(this);
+                }
+                dispatcher.dispatch(request);
             } catch (InterruptedException e) {
-                if (!wasClosed) {
-                    LOG.debug(
-                            taskName
-                                    + " interrupted while waiting for a request (continue waiting)",
-                            e);
-                } else {
-                    Thread.currentThread().interrupt();
+                synchronized (lock) {
+                    if (!wasClosed) {
+                        LOG.debug(
+                                "Channel state executor is interrupted while waiting for a request (continue waiting)",
+                                e);
+                    } else {
+                        Thread.currentThread().interrupt();
+                        return;
+                    }
                 }
             }
         }
     }
 
+    /**
+     * Removes and returns the head of {@link #deque}, waiting on {@link #lock} while the deque is
+     * empty. The caller must hold {@link #lock}.
+     *
+     * @return the head request, or {@code null} once the executor has been closed
+     * @throws InterruptedException if the thread is interrupted while waiting
+     */
+    @Nullable
+    private ChannelStateWriteRequest waitAndTakeUnsafe() throws InterruptedException {
+        while (!wasClosed) {
+            ChannelStateWriteRequest head = deque.pollFirst();
+            if (head != null) {
+                return head;
+            }
+            // Releases the lock until a submission, registration or close notifies it; spurious
+            // wake-ups simply re-check both conditions.
+            lock.wait();
+        }
+        return null;
+    }
+
     private void cleanupRequests() throws Exception {
-        Throwable cause = thrown == null ? new CancellationException() : thrown;
-        List<ChannelStateWriteRequest> drained = new ArrayList<>();
-        deque.drainTo(drained);
-        LOG.info("{} discarding {} drained requests", taskName, drained.size());
+        List<ChannelStateWriteRequest> drained;
+        Throwable cause;
+        synchronized (lock) {
+            cause = thrown == null ? new CancellationException() : thrown;
+            drained = new ArrayList<>(deque);
+            deque.clear();
+            for (Queue<ChannelStateWriteRequest> unreadyQueue : unreadyQueues.values()) {
+                while (!unreadyQueue.isEmpty()) {
+                    drained.add(unreadyQueue.poll());
+                }
+            }
+        }
+        LOG.info("discarding {} drained requests", drained.size());
         closeAll(
                 drained.stream()
                         .<AutoCloseable>map(request -> () -> request.cancel(cause))
@@ -125,12 +221,91 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
 
     @Override
     public void submit(ChannelStateWriteRequest request) throws Exception {
-        submitInternal(request, () -> deque.add(request));
+        synchronized (lock) {
+            Queue<ChannelStateWriteRequest> unreadyQueue =
+                    unreadyQueues.get(
+                            SubtaskID.of(request.getJobVertexID(), request.getSubtaskIndex()));
+            checkArgument(unreadyQueue != null, "The subtask %s is not yet registered.");
+            submitInternal(
+                    request,
+                    () -> {
+                        // 1. unreadyQueue isn't empty, the new request must keep the order, so add
+                        // the new request to the unreadyQueue tail.
+                        if (!unreadyQueue.isEmpty()) {
+                            unreadyQueue.add(request);
+                            return;
+                        }
+                        // 2. unreadyQueue is empty, and new request is ready, so add it to the
+                        // readyQueue directly.
+                        if (request.getReadyFuture().isDone()) {
+                            deque.add(request);
+                            lock.notifyAll();
+                            return;
+                        }
+                        // 3. unreadyQueue is empty, and new request isn't ready, so add it to the
+                        // unreadyQueue, and register it as the first request.
+                        unreadyQueue.add(request);
+                        registerFirstRequestFuture(request, unreadyQueue);
+                    });
+        }
+    }
+
+    /**
+     * Moves the requests of {@code unreadyQueue} to {@link #deque} once {@code firstRequest}, the
+     * head of that queue, becomes ready. Must be called while holding {@link #lock}.
+     */
+    private void registerFirstRequestFuture(
+            @Nonnull ChannelStateWriteRequest firstRequest,
+            Queue<ChannelStateWriteRequest> unreadyQueue) {
+        assert Thread.holdsLock(lock);
+        checkState(firstRequest == unreadyQueue.peek(), "The request isn't the first request.");
+
+        // A single whenComplete callback handles both outcomes exactly once. An exceptionally
+        // completed ready future is not inspected here: the request is moved anyway, and the
+        // failure is handled later by the channel state writer thread. The callback runs on the
+        // completing thread, or inline (the lock is reentrant) if the future is already done.
+        firstRequest
+                .getReadyFuture()
+                .whenComplete(
+                        (ignoredResult, ignoredFailure) -> {
+                            synchronized (lock) {
+                                moveReadyRequestToReadyQueue(unreadyQueue, firstRequest);
+                            }
+                        });
+    }
+
+    /**
+     * Moves requests from the head of {@code unreadyQueue} to {@link #deque} for as long as they
+     * are ready. At the first request that is still pending, it re-registers the readiness
+     * callback on that request. Must be called while holding {@link #lock}.
+     */
+    private void moveReadyRequestToReadyQueue(
+            Queue<ChannelStateWriteRequest> unreadyQueue, ChannelStateWriteRequest firstRequest) {
+        assert Thread.holdsLock(lock);
+        checkState(firstRequest == unreadyQueue.peek());
+        ChannelStateWriteRequest head;
+        while ((head = unreadyQueue.peek()) != null) {
+            if (head.getReadyFuture().isDone()) {
+                deque.add(Objects.requireNonNull(unreadyQueue.poll()));
+                // Wake up the writer thread waiting in waitAndTakeUnsafe.
+                lock.notifyAll();
+            } else {
+                // Still pending: resume from here once this request becomes ready.
+                registerFirstRequestFuture(head, unreadyQueue);
+                return;
+            }
+        }
     }
 
     @Override
     public void submitPriority(ChannelStateWriteRequest request) throws Exception {
-        submitInternal(request, () -> deque.addFirst(request));
+        synchronized (lock) {
+            checkArgument(
+                    unreadyQueues.containsKey(
+                            SubtaskID.of(request.getJobVertexID(), request.getSubtaskIndex())),
+                    "The subtask %s is not yet registered.");
+            checkState(request.getReadyFuture().isDone(), "The priority request must be ready.");
+            submitInternal(
+                    request,
+                    () -> {
+                        deque.addFirst(request);
+                        lock.notifyAll();
+                    });
+        }
     }
 
     private void submitInternal(ChannelStateWriteRequest request, RunnableWithException action)
@@ -145,6 +320,7 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
     }
 
     private void ensureRunning() throws Exception {
+        assert Thread.holdsLock(lock);
         // this check should be performed *at least after* enqueuing a request
         // checking before is not enough because (check + enqueue) is not atomic
         if (wasClosed || !thread.isAlive()) {
@@ -158,8 +334,63 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
     }
 
     @Override
-    public void close() throws IOException {
-        wasClosed = true;
+    public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+        SubtaskID subtaskID = SubtaskID.of(jobVertexID, subtaskIndex);
+        boolean completeRegister = false;
+        synchronized (lock) {
+            checkState(isRegistering(), "This executor has been registered.");
+            checkState(
+                    !subtasks.contains(subtaskID),
+                    String.format("This subtask[%s] has already registered.", subtaskID));
+            subtasks.add(subtaskID);
+            deque.add(
+                    ChannelStateWriteRequest.registerSubtask(
+                            subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex()));
+            lock.notifyAll();
+            unreadyQueues.put(subtaskID, new ArrayDeque<>());
+            if (subtasks.size() == maxSubtasksPerChannelStateFile) {
+                completeRegister = completeRegister();
+            }
+        }
+        if (completeRegister) {
+            onRegistered.accept(this);
+        }
+    }
+
+    /** Returns whether this executor still accepts new subtasks. */
+    @VisibleForTesting
+    public boolean isRegistering() {
+        synchronized (lock) {
+            return isRegistering;
+        }
+    }
+
+    /**
+     * Ends the registration phase. Must be called while holding {@link #lock}.
+     *
+     * @return {@code true} only for the call that actually ended the registration; that caller is
+     *     then responsible for invoking {@link #onRegistered} outside the lock
+     */
+    private boolean completeRegister() {
+        assert Thread.holdsLock(lock);
+        boolean wasRegistering = isRegistering;
+        isRegistering = false;
+        return wasRegistering;
+    }
+
+    /**
+     * Releases a subtask. Any release ends the registration phase. Only the release of the last
+     * registered subtask closes the executor. It repeatedly interrupts the writer thread while that
+     * thread is alive, then wraps a recorded writer failure in an {@link IOException}.
+     */
+    @Override
+    public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) throws IOException {
+        boolean completeRegister = false;
+        try {
+            synchronized (lock) {
+                // A release also ends the registration phase if it is still open.
+                completeRegister = completeRegister();
+                subtasks.remove(SubtaskID.of(jobVertexID, subtaskIndex));
+                if (!subtasks.isEmpty()) {
+                    // Other subtasks still use this executor, so keep it running.
+                    return;
+                }
+                wasClosed = true;
+                lock.notifyAll();
+            }
+        } finally {
+            // Outside the lock, see onRegistered.
+            if (completeRegister) {
+                onRegistered.accept(this);
+            }
+        }
         while (thread.isAlive()) {
             thread.interrupt();
             try {
@@ -168,11 +399,15 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
                 if (!thread.isAlive()) {
                     Thread.currentThread().interrupt();
                 }
-                LOG.debug(taskName + " interrupted while waiting for the writer thread to die", e);
+                LOG.debug(
+                        "Channel state executor is interrupted while waiting for the writer thread to die",
+                        e);
             }
         }
-        if (thrown != null) {
-            throw new IOException(thrown);
+        synchronized (lock) {
+            if (thrown != null) {
+                throw new IOException(thrown);
+            }
         }
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
index b1091d2b2bf..9b80813cf81 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
@@ -21,8 +21,9 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.CheckpointStateOutputStream;
-import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
+import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.util.CloseableIterator;
 import org.apache.flink.util.Preconditions;
 
@@ -36,6 +37,7 @@ import java.util.List;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput;
@@ -65,18 +67,37 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
             1000; // includes max-concurrent-checkpoints + checkpoints to be aborted (scheduled via
     // mailbox)
 
+    /** Vertex of the owning subtask, attached to the requests this writer sends. */
+    private final JobVertexID jobVertexID;
+
+    /** Index of the owning subtask, attached to the requests this writer sends. */
+    private final int subtaskIndex;
+
     private final String taskName;
+
     private final ChannelStateWriteRequestExecutor executor;
     private final ConcurrentMap<Long, ChannelStateWriteResult> results;
     private final int maxCheckpoints;
+
+    /** Ensures that {@link #close()} releases this subtask from the executor only once. */
+    private final AtomicBoolean wasClosed = new AtomicBoolean(false);
+
     /**
      * Creates a {@link ChannelStateWriterImpl} with {@link #DEFAULT_MAX_CHECKPOINTS} as {@link
      * #maxCheckpoints}.
+     *
+     * @param jobVertexID vertex of the subtask that owns this writer
+     * @param taskName task name used in log messages
+     * @param subtaskIndex index of the subtask that owns this writer
+     * @param checkpointStorage passed to the executor factory, which uses it when it creates a new
+     *     executor
+     * @param channelStateExecutorFactory provides the executor this subtask registers with; the
+     *     executor may be shared with other subtasks
+     * @param maxSubtasksPerChannelStateFile maximum number of subtasks that share one executor
      */
     public ChannelStateWriterImpl(
-            String taskName, int subtaskIndex, CheckpointStorageWorkerView streamFactoryResolver) {
-        this(taskName, subtaskIndex, streamFactoryResolver, DEFAULT_MAX_CHECKPOINTS);
+            JobVertexID jobVertexID,
+            String taskName,
+            int subtaskIndex,
+            CheckpointStorage checkpointStorage,
+            ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory,
+            int maxSubtasksPerChannelStateFile) {
+        this(
+                jobVertexID,
+                taskName,
+                subtaskIndex,
+                checkpointStorage,
+                DEFAULT_MAX_CHECKPOINTS,
+                channelStateExecutorFactory,
+                maxSubtasksPerChannelStateFile);
     }
 
     /**
@@ -84,34 +105,41 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
      * {@link ChannelStateSerializer}, and a {@link ChannelStateWriteRequestExecutorImpl}.
      *
      * @param taskName
-     * @param streamFactoryResolver a factory to obtain output stream factory for a given checkpoint
+     * @param checkpointStorage a factory to obtain output stream factory for a given checkpoint
      * @param maxCheckpoints maximum number of checkpoints to be written currently or finished but
      *     not taken yet.
+     * @param jobVertexID vertex of the subtask that owns this writer
+     * @param subtaskIndex index of the subtask that owns this writer
+     * @param channelStateExecutorFactory provides the (possibly shared) executor; creating this
+     *     writer registers the subtask with it
+     * @param maxSubtasksPerChannelStateFile maximum number of subtasks that share one executor
      */
     ChannelStateWriterImpl(
+            JobVertexID jobVertexID,
             String taskName,
             int subtaskIndex,
-            CheckpointStorageWorkerView streamFactoryResolver,
-            int maxCheckpoints) {
+            CheckpointStorage checkpointStorage,
+            int maxCheckpoints,
+            ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory,
+            int maxSubtasksPerChannelStateFile) {
         this(
+                jobVertexID,
                 taskName,
+                subtaskIndex,
                 new ConcurrentHashMap<>(maxCheckpoints),
-                new ChannelStateWriteRequestExecutorImpl(
-                        taskName,
-                        new ChannelStateWriteRequestDispatcherImpl(
-                                taskName,
-                                subtaskIndex,
-                                streamFactoryResolver,
-                                new ChannelStateSerializerImpl())),
+                channelStateExecutorFactory.getOrCreateExecutor(
+                        jobVertexID,
+                        subtaskIndex,
+                        checkpointStorage,
+                        maxSubtasksPerChannelStateFile),
                 maxCheckpoints);
     }
 
     ChannelStateWriterImpl(
+            JobVertexID jobVertexID,
             String taskName,
+            int subtaskIndex,
             ConcurrentMap<Long, ChannelStateWriteResult> results,
             ChannelStateWriteRequestExecutor executor,
             int maxCheckpoints) {
+        this.jobVertexID = jobVertexID;
         this.taskName = taskName;
+        this.subtaskIndex = subtaskIndex;
         this.results = results;
         this.maxCheckpoints = maxCheckpoints;
         this.executor = executor;
@@ -135,6 +163,8 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
                                             maxCheckpoints));
                             enqueue(
                                     new CheckpointStartRequest(
+                                            jobVertexID,
+                                            subtaskIndex,
                                             checkpointId,
                                             result,
                                             checkpointOptions.getTargetLocation()),
@@ -158,7 +188,7 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
                 checkpointId,
                 info,
                 startSeqNum);
-        enqueue(write(checkpointId, info, iterator), false);
+        enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, iterator), false);
     }
 
     @Override
@@ -171,7 +201,7 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
                 info,
                 startSeqNum,
                 data == null ? 0 : data.length);
-        enqueue(write(checkpointId, info, data), false);
+        enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, data), false);
     }
 
     @Override
@@ -187,27 +217,29 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
                 checkpointId,
                 info,
                 startSeqNum);
-        enqueue(write(checkpointId, info, dataFuture), false);
+        enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, dataFuture), false);
     }
 
     @Override
     public void finishInput(long checkpointId) {
         LOG.debug("{} finishing input data, checkpoint {}", taskName, checkpointId);
-        enqueue(completeInput(checkpointId), false);
+        enqueue(completeInput(jobVertexID, subtaskIndex, checkpointId), false);
     }
 
     @Override
     public void finishOutput(long checkpointId) {
         LOG.debug("{} finishing output data, checkpoint {}", taskName, checkpointId);
-        enqueue(completeOutput(checkpointId), false);
+        enqueue(completeOutput(jobVertexID, subtaskIndex, checkpointId), false);
     }
 
     @Override
     public void abort(long checkpointId, Throwable cause, boolean cleanup) {
         LOG.debug("{} aborting, checkpoint {}", taskName, checkpointId);
-        enqueue(ChannelStateWriteRequest.abort(checkpointId, cause), true); // abort already started
         enqueue(
-                ChannelStateWriteRequest.abort(checkpointId, cause),
+                ChannelStateWriteRequest.abort(jobVertexID, subtaskIndex, checkpointId, cause),
+                true); // abort already started
+        enqueue(
+                ChannelStateWriteRequest.abort(jobVertexID, subtaskIndex, checkpointId, cause),
                 false); // abort enqueued but not started
         if (cleanup) {
             results.remove(checkpointId);
@@ -230,15 +262,19 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
         return results.get(checkpointId);
     }
 
-    public void open() {
-        executor.start();
-    }
-
     @Override
     public void close() throws IOException {
-        LOG.debug("close, dropping checkpoints {}", results.keySet());
-        results.clear();
-        executor.close();
+        // Only the first call releases this subtask; later calls do nothing.
+        if (wasClosed.compareAndSet(false, true)) {
+            LOG.debug("close, dropping checkpoints {}", results.keySet());
+            results.clear();
+            try {
+                enqueue(ChannelStateWriteRequest.releaseSubtask(jobVertexID, subtaskIndex), false);
+            } finally {
+                // Unregister from the executor even if enqueueing the release request throws.
+                executor.releaseSubtask(jobVertexID, subtaskIndex);
+            }
+        }
     }
 
     private void enqueue(ChannelStateWriteRequest request, boolean atTheFront) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
index 40a587fd5c3..bd8d366f29a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -261,4 +262,11 @@ public interface Environment {
     default CheckpointStorageAccess getCheckpointStorageAccess() {
         throw new UnsupportedOperationException();
     }
+
+    /**
+     * Returns the factory of the executors that process channel state write requests.
+     *
+     * @return the channel state write request executor factory
+     */
+    ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java
index 725182bebbf..204f3d3c074 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java
@@ -97,11 +97,10 @@ public class NetworkActionsLogger {
     }
 
     public static void tracePersist(
-            String action, Buffer buffer, String taskName, Object channelInfo, long checkpointId) {
+            String action, Buffer buffer, Object channelInfo, long checkpointId) {
         if (LOG.isTraceEnabled()) {
             LOG.trace(
-                    "[{}] {} {}, checkpoint {} @ {}",
-                    taskName,
+                    "{} {}, checkpoint {} @ {}",
                     action,
                     buffer.toDebugString(INCLUDE_HASH),
                     checkpointId,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManager.java
new file mode 100644
index 00000000000..db5ff6d416b
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManager.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.util.Preconditions;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Holds all {@link ChannelStateWriteRequestExecutorFactory} objects of a task executor (task
+ * manager), keeping at most one factory per job.
+ */
+@ThreadSafe
+public class TaskExecutorChannelStateExecutorFactoryManager {
+
+    /** Logger for this class. */
+    private static final Logger LOG =
+            LoggerFactory.getLogger(TaskExecutorChannelStateExecutorFactoryManager.class);
+
+    private final Object lock = new Object();
+
+    @GuardedBy("lock")
+    private final Map<JobID, ChannelStateWriteRequestExecutorFactory> executorFactoryByJobId;
+
+    @GuardedBy("lock")
+    private boolean closed;
+
+    /** Creates an empty manager that hands out factories until {@link #shutdown()} is called. */
+    public TaskExecutorChannelStateExecutorFactoryManager() {
+        this.executorFactoryByJobId = new HashMap<>();
+        this.closed = false;
+    }
+
+    /**
+     * Returns the executor factory of the given job, creating and registering it if absent.
+     *
+     * @param jobID the job whose factory is requested; must not be null
+     * @return the factory registered for {@code jobID}
+     * @throws IllegalStateException if this manager has already been shut down
+     */
+    public ChannelStateWriteRequestExecutorFactory getOrCreateExecutorFactory(
+            @Nonnull JobID jobID) {
+        // Enforce the @Nonnull contract; the HashMap would otherwise accept a null key.
+        Preconditions.checkNotNull(jobID, "jobID must not be null");
+        synchronized (lock) {
+            if (closed) {
+                throw new IllegalStateException(
+                        "TaskExecutorChannelStateExecutorFactoryManager is already closed and cannot "
+                                + "get a new executor factory.");
+            }
+            ChannelStateWriteRequestExecutorFactory factory = executorFactoryByJobId.get(jobID);
+            if (factory == null) {
+                LOG.info("Creating the channel state executor factory for job id {}", jobID);
+                factory = new ChannelStateWriteRequestExecutorFactory(jobID);
+                executorFactoryByJobId.put(jobID, factory);
+            }
+            return factory;
+        }
+    }
+
+    /**
+     * Removes the executor factory of the given job from this manager, if present, unless this
+     * manager has already been shut down.
+     *
+     * @param jobID the job whose factory is removed
+     */
+    public void releaseResourcesForJob(@Nonnull JobID jobID) {
+        LOG.debug("Releasing the factory under job id {}", jobID);
+        synchronized (lock) {
+            if (closed) {
+                return;
+            }
+            executorFactoryByJobId.remove(jobID);
+        }
+    }
+
+    /**
+     * Marks this manager as closed and drops all registered factories. Afterwards {@link
+     * #getOrCreateExecutorFactory} throws; repeated calls to this method have no effect.
+     */
+    public void shutdown() {
+        synchronized (lock) {
+            if (closed) {
+                return;
+            }
+            closed = true;
+            executorFactoryByJobId.clear();
+            LOG.info("Shutting down TaskExecutorChannelStateExecutorFactoryManager.");
+        }
+    }
+
+    /** Returns the factory currently registered for the given job, or {@code null} if none. */
+    @VisibleForTesting
+    @Nullable
+    public ChannelStateWriteRequestExecutorFactory getFactoryByJobId(JobID jobId) {
+        synchronized (lock) {
+            return executorFactoryByJobId.get(jobId);
+        }
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
index 9c3dd3744af..e7226b9abcd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
@@ -97,6 +97,7 @@ import org.apache.flink.runtime.rpc.RpcServiceUtils;
 import org.apache.flink.runtime.security.token.DelegationTokenReceiverRepository;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.state.TaskExecutorChannelStateExecutorFactoryManager;
 import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
 import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
 import org.apache.flink.runtime.state.TaskLocalStateStore;
@@ -217,6 +218,12 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
     /** The changelog manager for this task, providing changelog storage per job. */
     private final TaskExecutorStateChangelogStoragesManager changelogStoragesManager;
 
+    /**
+     * The channel state executor factory manager for this task executor, providing a channel
+     * state executor factory per job.
+     */
+    private final TaskExecutorChannelStateExecutorFactoryManager channelStateExecutorFactoryManager;
+
     /** Information provider for external resources. */
     private final ExternalResourceInfoProvider externalResourceInfoProvider;
 
@@ -319,6 +326,8 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
                 taskExecutorServices.getUnresolvedTaskManagerLocation();
         this.localStateStoresManager = taskExecutorServices.getTaskManagerStateStore();
         this.changelogStoragesManager = taskExecutorServices.getTaskManagerChangelogManager();
+        this.channelStateExecutorFactoryManager =
+                taskExecutorServices.getTaskManagerChannelStateManager();
         this.shuffleEnvironment = taskExecutorServices.getShuffleEnvironment();
         this.kvStateService = taskExecutorServices.getKvStateService();
         this.ioExecutor = taskExecutorServices.getIOExecutor();
@@ -479,6 +488,7 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
         }
 
         changelogStoragesManager.shutdown();
+        channelStateExecutorFactoryManager.shutdown();
 
         Preconditions.checkState(jobTable.isEmpty());
 
@@ -771,7 +781,8 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
                             taskManagerConfiguration,
                             taskMetricGroup,
                             partitionStateChecker,
-                            getRpcService().getScheduledExecutor());
+                            getRpcService().getScheduledExecutor(),
+                            channelStateExecutorFactoryManager.getOrCreateExecutorFactory(jobId));
 
             taskMetricGroup.gauge(MetricNames.IS_BACK_PRESSURED, task::isBackPressured);
 
@@ -1871,6 +1882,7 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
         taskManagerMetricGroup.removeJobMetricsGroup(jobId);
         changelogStoragesManager.releaseResourcesForJob(jobId);
         currentSlotOfferPerJob.remove(jobId);
+        channelStateExecutorFactoryManager.releaseResourcesForJob(jobId);
     }
 
     private void scheduleResultPartitionCleanup(JobID jobId) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java
index d439146f263..bf4e11486b5 100755
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.rpc.FatalErrorHandler;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
 import org.apache.flink.runtime.shuffle.ShuffleServiceLoader;
+import org.apache.flink.runtime.state.TaskExecutorChannelStateExecutorFactoryManager;
 import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
 import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
 import org.apache.flink.runtime.taskexecutor.slot.DefaultTimerService;
@@ -80,6 +81,7 @@ public class TaskManagerServices {
     private final JobLeaderService jobLeaderService;
     private final TaskExecutorLocalStateStoresManager taskManagerStateStore;
     private final TaskExecutorStateChangelogStoragesManager taskManagerChangelogManager;
+    private final TaskExecutorChannelStateExecutorFactoryManager taskManagerChannelStateManager;
     private final TaskEventDispatcher taskEventDispatcher;
     private final ExecutorService ioExecutor;
     private final LibraryCacheManager libraryCacheManager;
@@ -98,6 +100,7 @@ public class TaskManagerServices {
             JobLeaderService jobLeaderService,
             TaskExecutorLocalStateStoresManager taskManagerStateStore,
             TaskExecutorStateChangelogStoragesManager taskManagerChangelogManager,
+            TaskExecutorChannelStateExecutorFactoryManager taskManagerChannelStateManager,
             TaskEventDispatcher taskEventDispatcher,
             ExecutorService ioExecutor,
             LibraryCacheManager libraryCacheManager,
@@ -116,6 +119,8 @@ public class TaskManagerServices {
         this.jobLeaderService = Preconditions.checkNotNull(jobLeaderService);
         this.taskManagerStateStore = Preconditions.checkNotNull(taskManagerStateStore);
         this.taskManagerChangelogManager = Preconditions.checkNotNull(taskManagerChangelogManager);
+        this.taskManagerChannelStateManager =
+                Preconditions.checkNotNull(taskManagerChannelStateManager);
         this.taskEventDispatcher = Preconditions.checkNotNull(taskEventDispatcher);
         this.ioExecutor = Preconditions.checkNotNull(ioExecutor);
         this.libraryCacheManager = Preconditions.checkNotNull(libraryCacheManager);
@@ -171,6 +175,10 @@ public class TaskManagerServices {
         return taskManagerChangelogManager;
     }
 
+    public TaskExecutorChannelStateExecutorFactoryManager getTaskManagerChannelStateManager() {
+        return taskManagerChannelStateManager;
+    }
+
     public TaskEventDispatcher getTaskEventDispatcher() {
         return taskEventDispatcher;
     }
@@ -341,6 +349,9 @@ public class TaskManagerServices {
         final TaskExecutorStateChangelogStoragesManager changelogStoragesManager =
                 new TaskExecutorStateChangelogStoragesManager();
 
+        final TaskExecutorChannelStateExecutorFactoryManager channelStateExecutorFactoryManager =
+                new TaskExecutorChannelStateExecutorFactoryManager();
+
         final boolean failOnJvmMetaspaceOomError =
                 taskManagerServicesConfiguration
                         .getConfiguration()
@@ -382,6 +393,7 @@ public class TaskManagerServices {
                 jobLeaderService,
                 taskStateManager,
                 changelogStoragesManager,
+                channelStateExecutorFactoryManager,
                 taskEventDispatcher,
                 ioExecutor,
                 libraryCacheManager,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index 74ce5e3fd32..e1d28ab359c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -106,6 +107,8 @@ public class RuntimeEnvironment implements Environment {
 
     @Nullable private CheckpointStorageAccess checkpointStorageAccess;
 
+    private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
     // ------------------------------------------------------------------------
 
     public RuntimeEnvironment(
@@ -135,7 +138,8 @@ public class RuntimeEnvironment implements Environment {
             TaskManagerRuntimeInfo taskManagerInfo,
             TaskMetricGroup metrics,
             Task containingTask,
-            ExternalResourceInfoProvider externalResourceInfoProvider) {
+            ExternalResourceInfoProvider externalResourceInfoProvider,
+            ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
 
         this.jobId = checkNotNull(jobId);
         this.jobVertexId = checkNotNull(jobVertexId);
@@ -164,6 +168,7 @@ public class RuntimeEnvironment implements Environment {
         this.containingTask = containingTask;
         this.metrics = metrics;
         this.externalResourceInfoProvider = checkNotNull(externalResourceInfoProvider);
+        this.channelStateExecutorFactory = checkNotNull(channelStateExecutorFactory);
     }
 
     // ------------------------------------------------------------------------
@@ -368,4 +373,9 @@ public class RuntimeEnvironment implements Environment {
         return checkNotNull(
                 checkpointStorageAccess, "checkpointStorage has not been initialized yet!");
     }
+
+    @Override
+    public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+        return channelStateExecutorFactory;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 059534f346e..acdf1b00d56 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointStoreUtil;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -262,6 +263,9 @@ public class Task
     /** Future that is completed once {@link #run()} exits. */
     private final CompletableFuture<ExecutionState> terminationFuture = new CompletableFuture<>();
 
+    /** Factory of channel state write request executors; handed to the RuntimeEnvironment. */
+    private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
     // ------------------------------------------------------------------------
     //  Fields that control the task execution. All these fields are volatile
     //  (which means that they introduce memory barriers), to establish
@@ -324,7 +328,8 @@ public class Task
             TaskManagerRuntimeInfo taskManagerConfig,
             @Nonnull TaskMetricGroup metricGroup,
             PartitionProducerStateChecker partitionProducerStateChecker,
-            Executor executor) {
+            Executor executor,
+            ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
 
         Preconditions.checkNotNull(jobInformation);
         Preconditions.checkNotNull(taskInformation);
@@ -382,6 +387,7 @@ public class Task
         this.partitionProducerStateChecker =
                 Preconditions.checkNotNull(partitionProducerStateChecker);
         this.executor = Preconditions.checkNotNull(executor);
+        this.channelStateExecutorFactory = Preconditions.checkNotNull(channelStateExecutorFactory);
 
         // create the reader and writer structures
 
@@ -708,7 +714,8 @@ public class Task
                             taskManagerConfig,
                             metrics,
                             this,
-                            externalResourceInfoProvider);
+                            externalResourceInfoProvider,
+                            channelStateExecutorFactory);
 
             // Make sure the user code classloader is accessible thread-locally.
             // We are setting the correct context class loader before instantiating the invokable
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
index d135cbdd500..c2c3efd4b34 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelSta
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.CheckpointStateOutputStream;
 import org.apache.flink.runtime.state.InputChannelStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
@@ -34,28 +35,42 @@ import org.apache.flink.util.function.RunnableWithException;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.io.TempDir;
 
+import javax.annotation.Nullable;
+
 import java.io.ByteArrayOutputStream;
 import java.io.DataOutputStream;
 import java.io.File;
 import java.io.IOException;
 import java.nio.file.Path;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.Map;
 import java.util.Random;
+import java.util.Set;
 import java.util.stream.IntStream;
 
 import static java.util.Collections.singletonList;
 import static org.apache.flink.core.fs.Path.fromLocalFile;
 import static org.apache.flink.core.fs.local.LocalFileSystem.getSharedInstance;
 import static org.apache.flink.core.memory.MemorySegmentFactory.wrap;
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertAllSubtaskDoneNormally;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertAllSubtaskNotDone;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertCheckpointFailureReason;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertHasSpecialCause;
 import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.fail;
 
 /** {@link ChannelStateCheckpointWriter} test. */
 class ChannelStateCheckpointWriterTest {
     private static final RunnableWithException NO_OP_RUNNABLE = () -> {};
     private final Random random = new Random();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
+    private static final SubtaskID SUBTASK_ID = SubtaskID.of(JOB_VERTEX_ID, SUBTASK_INDEX);
 
     @TempDir private Path temporaryFolder;
 
@@ -89,8 +104,8 @@ class ChannelStateCheckpointWriterTest {
                 write(writer, channels[channel], getData(numBytesPerWrite));
             }
         }
-        writer.completeInput();
-        writer.completeOutput();
+        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
+        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
 
         for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
             assertThat(handle.getStateSize())
@@ -120,9 +135,9 @@ class ChannelStateCheckpointWriterTest {
                 new NetworkBuffer(
                         MemorySegmentFactory.allocateUnpooledSegment(threshold / 2),
                         FreeingBufferRecycler.INSTANCE);
-        writer.writeInput(new InputChannelInfo(1, 2), buffer);
-        writer.completeOutput();
-        writer.completeInput();
+        writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, new InputChannelInfo(1, 2), buffer);
+        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
+        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
         assertThat(result.isDone()).isTrue();
         assertThat(checkpointsDir).isEmptyDirectory();
         assertThat(sharedStateDir).isEmptyDirectory();
@@ -139,8 +154,8 @@ class ChannelStateCheckpointWriterTest {
                     }
                 };
         ChannelStateCheckpointWriter writer = createWriter(new ChannelStateWriteResult(), stream);
-        writer.completeOutput();
-        writer.completeInput();
+        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
+        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
         assertThat(stream.isClosed()).isTrue();
     }
 
@@ -149,9 +164,9 @@ class ChannelStateCheckpointWriterTest {
         ChannelStateCheckpointWriter writer = createWriter(new ChannelStateWriteResult());
         NetworkBuffer buffer =
                 new NetworkBuffer(
-                        MemorySegmentFactory.allocateUnpooledSegment(10, null),
+                        MemorySegmentFactory.allocateUnpooledSegment(10),
                         FreeingBufferRecycler.INSTANCE);
-        writer.writeInput(new InputChannelInfo(1, 2), buffer);
+        writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, new InputChannelInfo(1, 2), buffer);
         assertThat(buffer.isRecycled()).isTrue();
     }
 
@@ -174,29 +189,194 @@ class ChannelStateCheckpointWriterTest {
         FlushRecorder dataStream = new FlushRecorder();
         final ChannelStateCheckpointWriter writer =
                 new ChannelStateCheckpointWriter(
-                        "dummy task",
-                        0,
+                        Collections.singleton(SUBTASK_ID),
                         1L,
-                        new ChannelStateWriteResult(),
                         new ChannelStateSerializerImpl(),
                         NO_OP_RUNNABLE,
                         new MemoryCheckpointOutputStream(42),
                         dataStream);
-
-        writer.completeInput();
-        writer.completeOutput();
+        writer.registerSubtaskResult(SUBTASK_ID, new ChannelStateWriteResult());
+        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
+        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
 
         assertThat(dataStream.flushed).isTrue();
     }
 
     @Test
     void testResultCompletion() throws Exception {
-        ChannelStateWriteResult result = new ChannelStateWriteResult();
-        ChannelStateCheckpointWriter writer = createWriter(result);
-        writer.completeInput();
-        assertThat(result.isDone()).isFalse();
-        writer.completeOutput();
-        assertThat(result.isDone()).isTrue();
+        // Exercise 1 to 9 subtasks per channel state file; each helper run builds its own
+        // writer, stream and subtasks.
+        for (int subtaskCount = 1; subtaskCount < 10; subtaskCount++) {
+            testMultiTaskCompletionAndAssertResult(subtaskCount);
+        }
+    }
+
+    private void testMultiTaskCompletionAndAssertResult(int maxSubtasksPerChannelStateFile)
+            throws Exception {
+        // Every subtask gets its own vertex id and result, but all share one writer and stream.
+        Map<SubtaskID, ChannelStateWriteResult> results = new HashMap<>();
+        for (int idx = 0; idx < maxSubtasksPerChannelStateFile; idx++) {
+            results.put(SubtaskID.of(new JobVertexID(), idx), new ChannelStateWriteResult());
+        }
+        MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+        ChannelStateCheckpointWriter writer = createWriter(stream, results.keySet());
+        results.forEach(writer::registerSubtaskResult);
+
+        // Results stay pending and the stream open until the last subtask completes its output.
+        for (SubtaskID id : results.keySet()) {
+            assertAllSubtaskNotDone(results.values());
+            assertThat(stream.isClosed()).isFalse();
+            writer.completeInput(id.getJobVertexID(), id.getSubtaskIndex());
+            assertAllSubtaskNotDone(results.values());
+            assertThat(stream.isClosed()).isFalse();
+            writer.completeOutput(id.getJobVertexID(), id.getSubtaskIndex());
+        }
+        assertThat(stream.isClosed()).isTrue();
+        assertAllSubtaskDoneNormally(results.values());
+    }
+
+    @Test
+    void testTaskUnregister() throws Exception {
+        // Run the unregister scenario for several numbers of subtasks per channel state file.
+        for (int subtaskCount : new int[] {2, 3, 5, 10}) {
+            testTaskUnregisterAndAssertResult(subtaskCount);
+        }
+    }
+
+    /**
+     * Registers all subtasks but one, completes them, and checks that the shared channel state
+     * file is only finished once the never-registered subtask is released.
+     */
+    private void testTaskUnregisterAndAssertResult(int maxSubtasksPerChannelStateFile)
+            throws Exception {
+        Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+        for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+            subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+        }
+        MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+
+        // The first subtask never registers a result: remove it from the map and register only
+        // the remaining subtasks.
+        Iterator<SubtaskID> subtaskIterator = subtasks.keySet().iterator();
+        SubtaskID unregisterSubtask = subtaskIterator.next();
+        subtaskIterator.remove();
+        subtasks.forEach(writer::registerSubtaskResult);
+
+        for (SubtaskID subtaskID : subtasks.keySet()) {
+            writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+            writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+        }
+        // The file still waits for the unregistered subtask: nothing done, stream still open.
+        assertAllSubtaskNotDone(subtasks.values());
+        assertThat(stream.isClosed()).isFalse();
+
+        // Releasing the unregistered subtask lets the shared file finish normally.
+        writer.releaseSubtask(unregisterSubtask);
+        assertThat(stream.isClosed()).isTrue();
+        assertAllSubtaskDoneNormally(subtasks.values());
+    }
+
+    @Test
+    void testTaskFailThenCompleteOtherTask() {
+        // Run the failure scenario for several numbers of subtasks per channel state file.
+        for (int subtaskCount : new int[] {2, 3, 5, 10}) {
+            testTaskFailAfterAllTaskRegisteredAndAssertResult(subtaskCount);
+        }
+    }
+
+    /**
+     * Registers all subtasks, fails one of them, and checks that the shared stream is closed and
+     * that every subtask sharing the channel state file observes the failure.
+     */
+    private void testTaskFailAfterAllTaskRegisteredAndAssertResult(
+            int maxSubtasksPerChannelStateFile) {
+        Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+        for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+            subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+        }
+        MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+        subtasks.forEach(writer::registerSubtaskResult);
+        assertThat(stream.isClosed()).isFalse();
+
+        // Fail the first subtask in iteration order.
+        SubtaskID firstSubtask = subtasks.keySet().iterator().next();
+        writer.fail(
+                firstSubtask.getJobVertexID(), firstSubtask.getSubtaskIndex(), new TestException());
+        assertThat(stream.isClosed()).isTrue();
+
+        // The failed subtask keeps its own cause; all others fail with the shared-stream reason.
+        for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+            if (firstSubtask.equals(entry.getKey())) {
+                assertHasSpecialCause(entry.getValue(), TestException.class);
+                continue;
+            }
+            assertCheckpointFailureReason(entry.getValue(), CHANNEL_STATE_SHARED_STREAM_EXCEPTION);
+        }
+    }
+
+    @Test
+    void testCloseGetHandleThrowException() throws Exception {
+        Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+        for (int i = 0; i < 5; i++) {
+            subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+        }
+        CloseExceptionOutputStream stream = new CloseExceptionOutputStream();
+        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+        for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+            SubtaskID subtaskID = entry.getKey();
+            writer.registerSubtaskResult(subtaskID, entry.getValue());
+            NetworkBuffer buffer =
+                    new NetworkBuffer(
+                            MemorySegmentFactory.allocateUnpooledSegment(10),
+                            FreeingBufferRecycler.INSTANCE);
+            writer.writeInput(
+                    subtaskID.getJobVertexID(),
+                    subtaskID.getSubtaskIndex(),
+                    new InputChannelInfo(1, 2),
+                    buffer);
+        }
+
+        for (SubtaskID subtaskID : subtasks.keySet()) {
+            assertAllSubtaskNotDone(subtasks.values());
+            assertThat(stream.isClosed()).isFalse();
+            writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+            assertAllSubtaskNotDone(subtasks.values());
+            assertThat(stream.isClosed()).isFalse();
+            writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+        }
+        assertThat(stream.isClosed()).isTrue();
+        for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+            assertThatThrownBy(() -> entry.getValue().getInputChannelStateHandles().get())
+                    .cause()
+                    .isInstanceOf(IOException.class)
+                    .hasMessage("Test closeAndGetHandle exception.");
+            assertThatThrownBy(() -> entry.getValue().getResultSubpartitionStateHandles().get())
+                    .cause()
+                    .isInstanceOf(IOException.class)
+                    .hasMessage("Test closeAndGetHandle exception.");
+        }
+    }
+
+    @Test
+    void testRegisterSubtaskAfterWriterDone() {
+        Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+        for (int i = 0; i < 2; i++) {
+            subtasks.put(SubtaskID.of(JOB_VERTEX_ID, i), new ChannelStateWriteResult());
+        }
+        MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+        ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+        // Failing an unregistered subtask is expected to mark the shared writer as done.
+        writer.fail(new JobVertexID(), 0, new TestException());
+        for (SubtaskID subtaskID : subtasks.keySet()) {
+            assertThatThrownBy(
+                            () ->
+                                    writer.registerSubtaskResult(
+                                            subtaskID, new ChannelStateWriteResult()))
+                    .isInstanceOf(IllegalStateException.class)
+                    .hasMessage("The write is done.");
+        }
     }
 
     @Test
@@ -214,8 +394,8 @@ class ChannelStateCheckpointWriterTest {
                 write(writer, e.getKey(), getData(numBytes));
             }
         }
-        writer.completeInput();
-        writer.completeOutput();
+        writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
+        writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
 
         for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
             int headerSize = Integer.BYTES;
@@ -245,7 +425,7 @@ class ChannelStateCheckpointWriterTest {
                         FreeingBufferRecycler.INSTANCE,
                         Buffer.DataType.DATA_BUFFER,
                         segment.size());
-        writer.writeInput(channelInfo, buffer);
+        writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, channelInfo, buffer);
     }
 
     private ChannelStateCheckpointWriter createWriter(ChannelStateWriteResult result) {
@@ -254,13 +434,28 @@ class ChannelStateCheckpointWriterTest {
 
     private ChannelStateCheckpointWriter createWriter(
             ChannelStateWriteResult result, CheckpointStateOutputStream stream) {
+        ChannelStateCheckpointWriter writer =
+                createWriter(stream, Collections.singleton(SUBTASK_ID));
+        writer.registerSubtaskResult(SUBTASK_ID, result);
+        return writer;
+    }
+
+    private ChannelStateCheckpointWriter createWriter(
+            CheckpointStateOutputStream stream, Set<SubtaskID> subtasks) {
         return new ChannelStateCheckpointWriter(
-                "dummy task",
-                0,
-                1L,
-                result,
-                stream,
-                new ChannelStateSerializerImpl(),
-                NO_OP_RUNNABLE);
+                subtasks, 1L, stream, new ChannelStateSerializerImpl(), NO_OP_RUNNABLE);
+    }
+}
+
+/** The output stream that throws an exception when close or closeAndGetHandle. */
+class CloseExceptionOutputStream extends MemoryCheckpointOutputStream {
+    public CloseExceptionOutputStream() {
+        super(1000);
+    }
+
+    @Nullable
+    @Override
+    public StreamStateHandle closeAndGetHandle() throws IOException {
+        throw new IOException("Test closeAndGetHandle exception.");
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
index 3d80f290856..5ece5423a39 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
@@ -19,29 +19,42 @@ package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
-import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAccess;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
 import java.util.function.Function;
 
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertCheckpointFailureReason;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertHasSpecialCause;
 import static org.apache.flink.util.CloseableIterator.ofElements;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** {@link ChannelStateWriteRequestDispatcherImpl} test. */
 class ChannelStateWriteRequestDispatcherImplTest {
 
+    private static final JobID JOB_ID = new JobID();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
+
     @Test
     void testPartialInputChannelStateWrite() throws Exception {
         testBuffersRecycled(
                 buffers ->
                         ChannelStateWriteRequest.write(
+                                JOB_VERTEX_ID,
+                                SUBTASK_INDEX,
                                 1L,
                                 new InputChannelInfo(1, 2),
                                 ofElements(Buffer::recycleBuffer, buffers)));
@@ -52,52 +65,167 @@ class ChannelStateWriteRequestDispatcherImplTest {
         testBuffersRecycled(
                 buffers ->
                         ChannelStateWriteRequest.write(
-                                1L, new ResultSubpartitionInfo(1, 2), buffers));
+                                JOB_VERTEX_ID,
+                                SUBTASK_INDEX,
+                                1L,
+                                new ResultSubpartitionInfo(1, 2),
+                                buffers));
+    }
+
+    private void testBuffersRecycled(
+            Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
+        ChannelStateWriteRequestDispatcher dispatcher =
+                new ChannelStateWriteRequestDispatcherImpl(
+                        new JobManagerCheckpointStorage(),
+                        JOB_ID,
+                        new ChannelStateSerializerImpl());
+        ChannelStateWriteResult result = new ChannelStateWriteResult();
+        dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
+        dispatcher.dispatch(
+                ChannelStateWriteRequest.start(
+                        JOB_VERTEX_ID,
+                        SUBTASK_INDEX,
+                        1L,
+                        result,
+                        CheckpointStorageLocationReference.getDefault()));
+
+        result.getResultSubpartitionStateHandles().completeExceptionally(new TestException());
+        result.getInputChannelStateHandles().completeExceptionally(new TestException());
+
+        NetworkBuffer[] buffers = new NetworkBuffer[] {buffer(), buffer()};
+        dispatcher.dispatch(requestBuilder.apply(buffers));
+        for (NetworkBuffer buffer : buffers) {
+            assertThat(buffer.isRecycled()).isTrue();
+        }
     }
 
     @Test
-    void testConcurrentUnalignedCheckpoint() throws Exception {
+    void testStartNewCheckpointForSameSubtask() throws Exception {
+        testStartNewCheckpointAndCheckOldCheckpointResult(false);
+    }
+
+    @Test
+    void testStartNewCheckpointForDifferentSubtask() throws Exception {
+        testStartNewCheckpointAndCheckOldCheckpointResult(true);
+    }
+
+    private void testStartNewCheckpointAndCheckOldCheckpointResult(boolean isDifferentSubtask)
+            throws Exception {
         ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
-                        "dummy task",
-                        0,
-                        getStreamFactoryFactory(),
+                        new JobManagerCheckpointStorage(),
+                        JOB_ID,
                         new ChannelStateSerializerImpl());
         ChannelStateWriteResult result = new ChannelStateWriteResult();
+        processor.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
+        JobVertexID newJobVertex = JOB_VERTEX_ID;
+        if (isDifferentSubtask) {
+            newJobVertex = new JobVertexID();
+            processor.dispatch(
+                    ChannelStateWriteRequest.registerSubtask(newJobVertex, SUBTASK_INDEX));
+        }
         processor.dispatch(
                 ChannelStateWriteRequest.start(
-                        1L, result, CheckpointStorageLocationReference.getDefault()));
+                        JOB_VERTEX_ID,
+                        SUBTASK_INDEX,
+                        1L,
+                        result,
+                        CheckpointStorageLocationReference.getDefault()));
         assertThat(result.isDone()).isFalse();
 
         processor.dispatch(
                 ChannelStateWriteRequest.start(
+                        newJobVertex,
+                        SUBTASK_INDEX,
                         2L,
                         new ChannelStateWriteResult(),
                         CheckpointStorageLocationReference.getDefault()));
-        assertThat(result.isDone()).isTrue();
-        assertThat(result.getInputChannelStateHandles()).isCompletedExceptionally();
-        assertThat(result.getResultSubpartitionStateHandles()).isCompletedExceptionally();
+        assertCheckpointFailureReason(result, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
     }
 
-    private void testBuffersRecycled(
-            Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
-        ChannelStateWriteRequestDispatcher dispatcher =
+    @Test
+    void testStartOldCheckpointForSameSubtask() throws Exception {
+        testStartOldCheckpointAfterNewCheckpointAborted(false);
+    }
+
+    @Test
+    void testStartOldCheckpointForDifferentSubtask() throws Exception {
+        testStartOldCheckpointAfterNewCheckpointAborted(true);
+    }
+
+    private void testStartOldCheckpointAfterNewCheckpointAborted(boolean isDifferentSubtask)
+            throws Exception {
+        ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
-                        "dummy task",
-                        0,
-                        new MemoryBackendCheckpointStorageAccess(new JobID(), null, null, 1),
+                        new JobManagerCheckpointStorage(),
+                        JOB_ID,
                         new ChannelStateSerializerImpl());
+        processor.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
+        JobVertexID newJobVertex = JOB_VERTEX_ID;
+        if (isDifferentSubtask) {
+            newJobVertex = new JobVertexID();
+            processor.dispatch(
+                    ChannelStateWriteRequest.registerSubtask(newJobVertex, SUBTASK_INDEX));
+        }
+        processor.dispatch(
+                ChannelStateWriteRequest.abort(
+                        JOB_VERTEX_ID, SUBTASK_INDEX, 2L, new TestException()));
+
         ChannelStateWriteResult result = new ChannelStateWriteResult();
-        dispatcher.dispatch(
+        processor.dispatch(
                 ChannelStateWriteRequest.start(
-                        1L, result, CheckpointStorageLocationReference.getDefault()));
+                        newJobVertex,
+                        SUBTASK_INDEX,
+                        1L,
+                        result,
+                        CheckpointStorageLocationReference.getDefault()));
+        assertCheckpointFailureReason(result, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
+    }
 
-        result.getResultSubpartitionStateHandles().completeExceptionally(new TestException());
-        result.getInputChannelStateHandles().completeExceptionally(new TestException());
+    @Test
+    void testAbortCheckpointAndCheckAllException() throws Exception {
+        testAbortCheckpointAndCheckAllException(1);
+        testAbortCheckpointAndCheckAllException(2);
+        testAbortCheckpointAndCheckAllException(3);
+        testAbortCheckpointAndCheckAllException(5);
+        testAbortCheckpointAndCheckAllException(10);
+    }
 
-        NetworkBuffer[] buffers = new NetworkBuffer[] {buffer(), buffer()};
-        dispatcher.dispatch(requestBuilder.apply(buffers));
-        assertThat(buffers).allMatch(NetworkBuffer::isRecycled);
+    private void testAbortCheckpointAndCheckAllException(int numberOfSubtask) throws Exception {
+        ChannelStateWriteRequestDispatcher processor =
+                new ChannelStateWriteRequestDispatcherImpl(
+                        new JobManagerCheckpointStorage(),
+                        JOB_ID,
+                        new ChannelStateSerializerImpl());
+        List<ChannelStateWriteResult> results = new ArrayList<>(numberOfSubtask);
+        for (int i = 0; i < numberOfSubtask; i++) {
+            processor.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, i));
+        }
+        long checkpointId = 1L;
+        int abortedSubtaskIndex = new Random().nextInt(numberOfSubtask);
+        processor.dispatch(
+                ChannelStateWriteRequest.abort(
+                        JOB_VERTEX_ID, abortedSubtaskIndex, checkpointId, new TestException()));
+        for (int i = 0; i < numberOfSubtask; i++) {
+            ChannelStateWriteResult result = new ChannelStateWriteResult();
+            results.add(result);
+            processor.dispatch(
+                    ChannelStateWriteRequest.start(
+                            JOB_VERTEX_ID,
+                            i,
+                            checkpointId,
+                            result,
+                            CheckpointStorageLocationReference.getDefault()));
+        }
+        assertThat(results).allMatch(ChannelStateWriteResult::isDone);
+        for (int i = 0; i < numberOfSubtask; i++) {
+            ChannelStateWriteResult result = results.get(i);
+            if (i == abortedSubtaskIndex) {
+                assertHasSpecialCause(result, TestException.class);
+            } else {
+                assertCheckpointFailureReason(result, CHANNEL_STATE_SHARED_STREAM_EXCEPTION);
+            }
+        }
     }
 
     private NetworkBuffer buffer() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
index 4b59e1235e6..36de469bbfa 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
@@ -17,12 +17,15 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
 import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
 import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
@@ -43,7 +46,6 @@ import static java.util.Optional.of;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput;
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.write;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.assertj.core.api.Assertions.fail;
 
 /** {@link ChannelStateWriteRequestDispatcherImpl} tests. */
@@ -90,6 +92,10 @@ public class ChannelStateWriteRequestDispatcherTest {
                 new Object[] {of(IllegalStateException.class), asList(start(), start())});
     }
 
+    private static final JobID JOB_ID = new JobID();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
+
     @Parameter public Optional<Class<Exception>> expectedException;
 
     @Parameter(value = 1)
@@ -98,15 +104,17 @@ public class ChannelStateWriteRequestDispatcherTest {
     private static final long CHECKPOINT_ID = 42L;
 
     private static CheckpointInProgressRequest completeOut() {
-        return completeOutput(CHECKPOINT_ID);
+        return completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX, CHECKPOINT_ID);
     }
 
     private static CheckpointInProgressRequest completeIn() {
-        return completeInput(CHECKPOINT_ID);
+        return completeInput(JOB_VERTEX_ID, SUBTASK_INDEX, CHECKPOINT_ID);
     }
 
     private static ChannelStateWriteRequest writeIn() {
         return write(
+                JOB_VERTEX_ID,
+                SUBTASK_INDEX,
                 CHECKPOINT_ID,
                 new InputChannelInfo(1, 1),
                 CloseableIterator.ofElement(
@@ -118,6 +126,8 @@ public class ChannelStateWriteRequestDispatcherTest {
 
     private static ChannelStateWriteRequest writeOut() {
         return write(
+                JOB_VERTEX_ID,
+                SUBTASK_INDEX,
                 CHECKPOINT_ID,
                 new ResultSubpartitionInfo(1, 1),
                 new NetworkBuffer(
@@ -128,7 +138,12 @@ public class ChannelStateWriteRequestDispatcherTest {
     private static ChannelStateWriteRequest writeOutFuture() {
         CompletableFuture<List<Buffer>> outFuture = new CompletableFuture<>();
         ChannelStateWriteRequest writeRequest =
-                write(CHECKPOINT_ID, new ResultSubpartitionInfo(1, 1), outFuture);
+                write(
+                        JOB_VERTEX_ID,
+                        SUBTASK_INDEX,
+                        CHECKPOINT_ID,
+                        new ResultSubpartitionInfo(1, 1),
+                        outFuture);
         outFuture.complete(
                 singletonList(
                         new NetworkBuffer(
@@ -137,8 +152,14 @@ public class ChannelStateWriteRequestDispatcherTest {
         return writeRequest;
     }
 
+    private static SubtaskRegisterRequest register() {
+        return new SubtaskRegisterRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
+    }
+
     private static CheckpointStartRequest start() {
         return new CheckpointStartRequest(
+                JOB_VERTEX_ID,
+                SUBTASK_INDEX,
                 CHECKPOINT_ID,
                 new ChannelStateWriteResult(),
                 new CheckpointStorageLocationReference(new byte[] {1}));
@@ -148,11 +169,11 @@ public class ChannelStateWriteRequestDispatcherTest {
     void doRun() {
         ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
-                        "dummy task",
-                        0,
-                        getStreamFactoryFactory(),
+                        new JobManagerCheckpointStorage(),
+                        JOB_ID,
                         new ChannelStateSerializerImpl());
         try {
+            processor.dispatch(register());
             for (ChannelStateWriteRequest request : requests) {
                 processor.dispatch(request);
             }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactoryTest.java
new file mode 100644
index 00000000000..4f67527be2d
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactoryTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.CheckpointStorage;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Random;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link ChannelStateWriteRequestExecutorFactory}. */
+class ChannelStateWriteRequestExecutorFactoryTest {
+
+    private static final CheckpointStorage CHECKPOINT_STORAGE = new JobManagerCheckpointStorage();
+
+    @Test
+    void testReuseExecutorForSameJobId() {
+        for (int maxSubtasksPerChannelStateFile : new int[] {1, 2, 3, 5, 10}) {
+            assertReuseExecutor(maxSubtasksPerChannelStateFile);
+        }
+    }
+
+    /**
+     * Requests executors for 100 subtasks of a single job and checks that the factory creates a
+     * new executor every {@code maxSubtasksPerChannelStateFile} requests and reuses it in between.
+     */
+    private void assertReuseExecutor(int maxSubtasksPerChannelStateFile) {
+        Random random = new Random();
+        ChannelStateWriteRequestExecutorFactory executorFactory =
+                new ChannelStateWriteRequestExecutorFactory(new JobID());
+        int numberOfTasks = 100;
+        ChannelStateWriteRequestExecutor currentExecutor = null;
+        for (int i = 0; i < numberOfTasks; i++) {
+            ChannelStateWriteRequestExecutor newExecutor =
+                    executorFactory.getOrCreateExecutor(
+                            new JobVertexID(),
+                            random.nextInt(numberOfTasks),
+                            CHECKPOINT_STORAGE,
+                            maxSubtasksPerChannelStateFile);
+            if (i % maxSubtasksPerChannelStateFile == 0) {
+                assertThat(newExecutor)
+                        .as("Factory should create the new executor.")
+                        .isNotSameAs(currentExecutor);
+                currentExecutor = newExecutor;
+            } else {
+                assertThat(newExecutor)
+                        .as("Factory should reuse the old executor.")
+                        .isSameAs(currentExecutor);
+            }
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
index b7aa7afb682..5c248c73895 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
@@ -17,21 +17,30 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.function.BiConsumerWithException;
 
 import org.junit.jupiter.api.Test;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.io.IOException;
+import java.util.ArrayDeque;
 import java.util.Collections;
+import java.util.Deque;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestDispatcher.NO_OP;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -40,7 +49,9 @@ import static org.assertj.core.api.Assertions.fail;
 /** {@link ChannelStateWriteRequestExecutorImpl} test. */
 class ChannelStateWriteRequestExecutorImplTest {
 
-    private static final String TASK_NAME = "test task";
+    private static final JobID JOB_ID = new JobID();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
 
     @Test
     void testCloseAfterSubmit() {
@@ -74,9 +85,10 @@ class ChannelStateWriteRequestExecutorImplTest {
             throws Exception {
         WorkerClosingDeque closingDeque = new WorkerClosingDeque();
         ChannelStateWriteRequestExecutorImpl worker =
-                new ChannelStateWriteRequestExecutorImpl(TASK_NAME, NO_OP, closingDeque);
+                new ChannelStateWriteRequestExecutorImpl(NO_OP, closingDeque, 5, e -> {});
         closingDeque.setWorker(worker);
-        TestWriteRequest request = new TestWriteRequest();
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        TestWriteRequest request = new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
         requestFun.accept(worker, request);
         assertThat(closingDeque).isEmpty();
         assertThat(request.isCancelled()).isFalse();
@@ -87,11 +99,13 @@ class ChannelStateWriteRequestExecutorImplTest {
                             ChannelStateWriteRequestExecutor, ChannelStateWriteRequest, Exception>
                     submitAction)
             throws Exception {
-        TestWriteRequest request = new TestWriteRequest();
-        LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
+        TestWriteRequest request = new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
+        Deque<ChannelStateWriteRequest> deque = new ArrayDeque<>();
         try {
-            submitAction.accept(
-                    new ChannelStateWriteRequestExecutorImpl(TASK_NAME, NO_OP, deque), request);
+            ChannelStateWriteRequestExecutorImpl executor =
+                    new ChannelStateWriteRequestExecutorImpl(NO_OP, deque, 5, e -> {});
+            executor.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+            submitAction.accept(executor, request);
         } catch (IllegalStateException e) {
             // expected: executor not started;
             return;
@@ -105,14 +119,14 @@ class ChannelStateWriteRequestExecutorImplTest {
     @Test
     @SuppressWarnings("CallToThreadRun")
     void testCleanup() throws IOException {
-        TestWriteRequest request = new TestWriteRequest();
-        LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
+        TestWriteRequest request = new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
+        Deque<ChannelStateWriteRequest> deque = new ArrayDeque<>();
         deque.add(request);
         TestRequestDispatcher requestProcessor = new TestRequestDispatcher();
         ChannelStateWriteRequestExecutorImpl worker =
-                new ChannelStateWriteRequestExecutorImpl(TASK_NAME, requestProcessor, deque);
-
-        worker.close();
+                new ChannelStateWriteRequestExecutorImpl(requestProcessor, deque, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
         worker.run();
 
         assertThat(requestProcessor.isStopped()).isTrue();
@@ -123,16 +137,31 @@ class ChannelStateWriteRequestExecutorImplTest {
     @Test
     void testIgnoresInterruptsWhileRunning() throws Exception {
-        TestRequestDispatcher requestProcessor = new TestRequestDispatcher();
-        LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
-        try (ChannelStateWriteRequestExecutorImpl worker =
-                new ChannelStateWriteRequestExecutorImpl(TASK_NAME, requestProcessor, deque)) {
+        // Wait for the dispatcher instead of polling the executor's queue: that queue is a
+        // non-thread-safe ArrayDeque mutated by the worker thread, so reading it from the test
+        // thread would be a data race.
+        CompletableFuture<Void> requestDispatched = new CompletableFuture<>();
+        TestRequestDispatcher requestProcessor =
+                new TestRequestDispatcher() {
+                    @Override
+                    public void dispatch(ChannelStateWriteRequest request) {
+                        // Ignore the subtask registration request, which is dispatched too.
+                        if (request instanceof TestWriteRequest) {
+                            requestDispatched.complete(null);
+                        }
+                    }
+                };
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(requestProcessor, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        try {
             worker.start();
             worker.getThread().interrupt();
-            worker.submit(new TestWriteRequest());
+            worker.submit(new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX));
             worker.getThread().interrupt();
-            while (!deque.isEmpty()) {
-                Thread.sleep(100);
-            }
+            // The worker must survive both interrupts and still dispatch the request.
+            requestDispatched.get();
+        } finally {
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
         }
     }
 
@@ -141,28 +159,123 @@ class ChannelStateWriteRequestExecutorImplTest {
         long checkpointId = 1L;
         ChannelStateWriteRequestDispatcher processor =
                 new ChannelStateWriteRequestDispatcherImpl(
-                        "dummy task",
-                        0,
-                        getStreamFactoryFactory(),
+                        new JobManagerCheckpointStorage(),
+                        JOB_ID,
                         new ChannelStateSerializerImpl());
-        try (ChannelStateWriteRequestExecutorImpl worker =
-                new ChannelStateWriteRequestExecutorImpl(TASK_NAME, processor)) {
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(processor, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        try {
             worker.start();
             worker.submit(
                     new CheckpointStartRequest(
+                            JOB_VERTEX_ID,
+                            SUBTASK_INDEX,
                             checkpointId,
                             new ChannelStateWriter.ChannelStateWriteResult(),
                             CheckpointStorageLocationReference.getDefault()));
             worker.submit(
                     ChannelStateWriteRequest.write(
+                            JOB_VERTEX_ID,
+                            SUBTASK_INDEX,
                             checkpointId,
                             new ResultSubpartitionInfo(0, 0),
                             new CompletableFuture<>()));
             worker.submit(
                     ChannelStateWriteRequest.write(
+                            JOB_VERTEX_ID,
+                            SUBTASK_INDEX,
                             checkpointId,
                             new ResultSubpartitionInfo(0, 0),
                             new CompletableFuture<>()));
+        } finally {
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        }
+    }
+
+    @Test
+    void testSkipUnreadyDataFuture() throws Exception {
+        int subtaskIndex0 = 0;
+        int subtaskIndex1 = 1;
+
+        Queue<ChannelStateWriteRequest> firstBatchRequests = new LinkedList<>();
+        Queue<ChannelStateWriteRequest> secondBatchRequests = new LinkedList<>();
+        CompletableFuture<List<Buffer>> dataFuture = new CompletableFuture<>();
+        int firstBatchSubtask1Count = 3;
+        int subtask0Count = 4;
+        int subtask1Count = 4;
+
+        {
+            // Build the first batch of requests.
+            firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+            firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0));
+            // Add a data future request: all subsequent requests of subtaskIndex0 must stay
+            // blocked until this future is completed.
+            firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0, dataFuture));
+            firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0));
+            firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+            firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+
+            // Build the second batch of requests.
+            secondBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0));
+            secondBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+        }
+
+        CompletableFuture<Void> firstBatchFuture = new CompletableFuture<>();
+        CompletableFuture<Void> allReceivedFuture = new CompletableFuture<>();
+
+        // Start at -1 so that the initial subtask registration request is not counted.
+        AtomicInteger subtask0ReceivedCounter = new AtomicInteger(-1);
+        AtomicInteger subtask1ReceivedCounter = new AtomicInteger(-1);
+
+        TestRequestDispatcher countingRequestProcessor =
+                new TestRequestDispatcher() {
+                    @Override
+                    public void dispatch(ChannelStateWriteRequest request) {
+                        if (request.getSubtaskIndex() == subtaskIndex0) {
+                            subtask0ReceivedCounter.incrementAndGet();
+                        } else if (request.getSubtaskIndex() == subtaskIndex1) {
+                            if (subtask1ReceivedCounter.incrementAndGet()
+                                    == firstBatchSubtask1Count) {
+                                firstBatchFuture.complete(null);
+                            }
+                        } else {
+                            throw new IllegalStateException(
+                                    String.format(
+                                            "Unknown subtask index %s.",
+                                            request.getSubtaskIndex()));
+                        }
+                        if (subtask0ReceivedCounter.get() == subtask0Count
+                                && subtask1ReceivedCounter.get() == subtask1Count) {
+                            allReceivedFuture.complete(null);
+                        }
+                    }
+                };
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(countingRequestProcessor, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, subtaskIndex0);
+        worker.registerSubtask(JOB_VERTEX_ID, subtaskIndex1);
+        try {
+            worker.start();
+            // start the first batch
+            for (ChannelStateWriteRequest request : firstBatchRequests) {
+                worker.submit(request);
+            }
+            firstBatchFuture.get();
+            assertThat(subtask0ReceivedCounter.get()).isOne();
+            assertThat(subtask1ReceivedCounter.get()).isEqualTo(firstBatchSubtask1Count);
+
+            // start the second batch
+            for (ChannelStateWriteRequest request : secondBatchRequests) {
+                worker.submit(request);
+            }
+            dataFuture.complete(Collections.emptyList());
+            allReceivedFuture.get();
+            assertThat(subtask0ReceivedCounter.get()).isEqualTo(subtask0Count);
+            assertThat(subtask1ReceivedCounter.get()).isEqualTo(subtask1Count);
+        } finally {
+            worker.releaseSubtask(JOB_VERTEX_ID, subtaskIndex0);
+            worker.releaseSubtask(JOB_VERTEX_ID, subtaskIndex1);
         }
     }
 
@@ -176,14 +289,17 @@ class ChannelStateWriteRequestExecutorImplTest {
                         throw testException;
                     }
                 };
-        LinkedBlockingDeque<ChannelStateWriteRequest> deque =
-                new LinkedBlockingDeque<>(Collections.singletonList(new TestWriteRequest()));
+        Deque<ChannelStateWriteRequest> deque =
+                new ArrayDeque<>(
+                        Collections.singletonList(
+                                new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX)));
         ChannelStateWriteRequestExecutorImpl worker =
                 new ChannelStateWriteRequestExecutorImpl(
-                        TASK_NAME, throwingRequestProcessor, deque);
+                        throwingRequestProcessor, deque, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
         worker.run();
         try {
-            worker.close();
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
         } catch (IOException e) {
             if (findThrowable(e, TestException.class)
                     .filter(found -> found == testException)
@@ -196,12 +312,153 @@ class ChannelStateWriteRequestExecutorImplTest {
         fail("exception not thrown");
     }
 
-    private static class TestWriteRequest implements ChannelStateWriteRequest {
+    @Test
+    void testSubmitRequestOfUnregisteredSubtask() throws Exception {
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(NO_OP, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        worker.start();
+        try {
+            worker.submit(new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX));
+
+            assertThatThrownBy(
+                            () ->
+                                    worker.submit(
+                                            new TestWriteRequest(new JobVertexID(), SUBTASK_INDEX)))
+                    .isInstanceOf(IllegalArgumentException.class)
+                    .hasMessageContaining("is not yet registered.");
+
+            assertThatThrownBy(
+                            () ->
+                                    worker.submitPriority(
+                                            new TestWriteRequest(new JobVertexID(), SUBTASK_INDEX)))
+                    .isInstanceOf(IllegalArgumentException.class)
+                    .hasMessageContaining("is not yet registered.");
+        } finally {
+            // Release the registered subtask so the started executor can shut down.
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        }
+    }
+
+    @Test
+    void testSubmitPriorityUnreadyRequest() throws Exception {
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(NO_OP, 5, e -> {});
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        worker.start();
+        try {
+            worker.submitPriority(new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX));
+
+            assertThatThrownBy(
+                            () ->
+                                    worker.submitPriority(
+                                            new TestWriteRequest(
+                                                    JOB_VERTEX_ID,
+                                                    SUBTASK_INDEX,
+                                                    new CompletableFuture<>())))
+                    .isInstanceOf(IllegalStateException.class)
+                    .hasMessage("The priority request must be ready.");
+        } finally {
+            // Release the registered subtask so the started executor can shut down.
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        }
+    }
+
+    @Test
+    void testRegisterSubtaskAfterRegisterCompleted() throws Exception {
+        int maxSubtasksPerChannelStateFile = 5;
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(
+                        NO_OP, maxSubtasksPerChannelStateFile, e -> {});
+        for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+            assertThat(worker.isRegistering()).isTrue();
+            worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX);
+        }
+        assertThat(worker.isRegistering()).isFalse();
+        assertThatThrownBy(() -> worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessage("This executor has been registered.");
+    }
+
+    @Test
+    void testSubmitStartRequestBeforeRegisterCompleted() throws Exception {
+        CompletableFuture<Void> dispatcherFuture = new CompletableFuture<>();
+        TestRequestDispatcher dispatcher =
+                new TestRequestDispatcher() {
+                    @Override
+                    public void dispatch(ChannelStateWriteRequest request) {
+                        if (request instanceof CheckpointStartRequest) {
+                            dispatcherFuture.complete(null);
+                        }
+                    }
+                };
+        int maxSubtasksPerChannelStateFile = 5;
+        CompletableFuture<ChannelStateWriteRequestExecutor> workerFuture =
+                new CompletableFuture<>();
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(
+                        dispatcher, maxSubtasksPerChannelStateFile, workerFuture::complete);
+        worker.start();
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        try {
+            assertThat(worker.isRegistering()).isTrue();
+
+            worker.submit(
+                    ChannelStateWriteRequest.start(
+                            JOB_VERTEX_ID,
+                            SUBTASK_INDEX,
+                            1,
+                            new ChannelStateWriter.ChannelStateWriteResult(),
+                            CheckpointStorageLocationReference.getDefault()));
+            dispatcherFuture.get();
+            assertThat(worker.isRegistering()).isFalse();
+            assertThat(workerFuture).isCompletedWithValue(worker);
+            assertThatThrownBy(() -> worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX))
+                    .isInstanceOf(IllegalStateException.class)
+                    .hasMessage("This executor has been registered.");
+        } finally {
+            // Release the registered subtask so the started executor can shut down.
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        }
+    }
+
+    @Test
+    void testReleaseSubtaskBeforeRegisterCompleted() throws Exception {
+        int maxSubtasksPerChannelStateFile = 5;
+        CompletableFuture<ChannelStateWriteRequestExecutor> workerFuture =
+                new CompletableFuture<>();
+        ChannelStateWriteRequestExecutorImpl worker =
+                new ChannelStateWriteRequestExecutorImpl(
+                        new TestRequestDispatcher(),
+                        maxSubtasksPerChannelStateFile,
+                        workerFuture::complete);
+        worker.start();
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        assertThat(worker.isRegistering()).isTrue();
+
+        worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        assertThat(worker.isRegistering()).isFalse();
+        assertThat(workerFuture).isCompletedWithValue(worker);
+        assertThatThrownBy(() -> worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX))
+                .isInstanceOf(IllegalStateException.class)
+                .hasMessage("This executor has been registered.");
+    }
+
+    private static class TestWriteRequest extends ChannelStateWriteRequest {
         private boolean cancelled = false;
 
-        @Override
-        public long getCheckpointId() {
-            return 0;
+        @Nullable private final CompletableFuture<?> readyFuture;
+
+        public TestWriteRequest(JobVertexID jobVertexID, int subtaskIndex) {
+            this(jobVertexID, subtaskIndex, null);
+        }
+
+        public TestWriteRequest(
+                JobVertexID jobVertexID,
+                int subtaskIndex,
+                @Nullable CompletableFuture<?> readyFuture) {
+            super(jobVertexID, subtaskIndex, 0, "Test");
+            this.readyFuture = readyFuture;
         }
 
         @Override
@@ -212,27 +452,35 @@ class ChannelStateWriteRequestExecutorImplTest {
         public boolean isCancelled() {
             return cancelled;
         }
+
+        @Override
+        public CompletableFuture<?> getReadyFuture() {
+            if (readyFuture != null) {
+                return readyFuture;
+            }
+            return super.getReadyFuture();
+        }
     }
 
-    private static class WorkerClosingDeque extends LinkedBlockingDeque<ChannelStateWriteRequest> {
+    private static class WorkerClosingDeque extends ArrayDeque<ChannelStateWriteRequest> {
         private ChannelStateWriteRequestExecutor worker;
 
         @Override
-        public void put(@Nonnull ChannelStateWriteRequest request) throws InterruptedException {
-            super.putFirst(request);
+        public boolean add(@Nonnull ChannelStateWriteRequest request) {
+            boolean add = super.add(request);
             try {
-                worker.close();
+                worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
             } catch (IOException e) {
                 ExceptionUtils.rethrow(e);
             }
+            return add;
         }
 
         @Override
-        public void putFirst(@Nonnull ChannelStateWriteRequest request)
-                throws InterruptedException {
-            super.putFirst(request);
+        public void addFirst(@Nonnull ChannelStateWriteRequest request) {
+            super.addFirst(request);
             try {
-                worker.close();
+                worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
             } catch (IOException e) {
                 ExceptionUtils.rethrow(e);
             }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteResultUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteResultUtil.java
new file mode 100644
index 00000000000..ee6a51a684f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteResultUtil.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
+
+import java.util.Collection;
+import java.util.concurrent.CompletableFuture;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
+
+/** Test assertions on {@link ChannelStateWriter.ChannelStateWriteResult}s. */
+public final class ChannelStateWriteResultUtil {
+
+    private ChannelStateWriteResultUtil() {}
+
+    /**
+     * Asserts that both the input channel and the result subpartition state handle futures of
+     * {@code result} fail with a direct cause of type {@code causeType}.
+     */
+    public static void assertHasSpecialCause(
+            ChannelStateWriter.ChannelStateWriteResult result,
+            Class<? extends Throwable> causeType) {
+        assertFailsWithCause(result.getInputChannelStateHandles(), causeType);
+        assertFailsWithCause(result.getResultSubpartitionStateHandles(), causeType);
+    }
+
+    /**
+     * Asserts that both state handle futures of {@code result} fail with a {@link
+     * CheckpointException} carrying {@code checkpointFailureReason}.
+     */
+    public static void assertCheckpointFailureReason(
+            ChannelStateWriter.ChannelStateWriteResult result,
+            CheckpointFailureReason checkpointFailureReason) {
+        assertFailsWithReason(result.getInputChannelStateHandles(), checkpointFailureReason);
+        assertFailsWithReason(result.getResultSubpartitionStateHandles(), checkpointFailureReason);
+    }
+
+    /** Asserts that every result is done and none of its state handle futures failed. */
+    public static void assertAllSubtaskDoneNormally(
+            Collection<ChannelStateWriter.ChannelStateWriteResult> results) {
+        assertThat(results)
+                .allMatch(ChannelStateWriter.ChannelStateWriteResult::isDone)
+                .allMatch(
+                        result -> !result.getInputChannelStateHandles().isCompletedExceptionally())
+                .allMatch(
+                        result ->
+                                !result.getResultSubpartitionStateHandles()
+                                        .isCompletedExceptionally());
+    }
+
+    /** Asserts that none of the results is done yet. */
+    public static void assertAllSubtaskNotDone(
+            Collection<ChannelStateWriter.ChannelStateWriteResult> results) {
+        assertThat(results).allMatch(result -> !result.isDone());
+    }
+
+    /** Asserts that {@code future.get()} fails with a direct cause of type {@code causeType}. */
+    private static void assertFailsWithCause(
+            CompletableFuture<?> future, Class<? extends Throwable> causeType) {
+        assertThatThrownBy(future::get).hasCauseInstanceOf(causeType);
+    }
+
+    /**
+     * Asserts that {@code future.get()} fails with a direct cause that is a {@link
+     * CheckpointException} carrying {@code checkpointFailureReason}.
+     */
+    private static void assertFailsWithReason(
+            CompletableFuture<?> future, CheckpointFailureReason checkpointFailureReason) {
+        assertThatThrownBy(future::get)
+                .cause()
+                .asInstanceOf(type(CheckpointException.class))
+                .satisfies(
+                        checkpointException ->
+                                assertThat(checkpointException.getCheckpointFailureReason())
+                                        .isEqualTo(checkpointFailureReason));
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
index fcf9ad65f2a..0b2cbe3cc6c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
@@ -17,12 +17,15 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.util.function.BiConsumerWithException;
 
 import org.junit.jupiter.api.Test;
@@ -33,7 +36,6 @@ import java.util.Deque;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
 
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.util.CloseableIterator.ofElements;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -42,6 +44,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 class ChannelStateWriterImplTest {
     private static final long CHECKPOINT_ID = 42L;
     private static final String TASK_NAME = "test";
+    private static final JobID JOB_ID = new JobID();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
 
     @Test
     void testAddEventBuffer() throws Exception {
@@ -141,17 +146,21 @@ class ChannelStateWriterImplTest {
     }
 
     @Test
-    void testBuffersRecycledOnError() throws IOException {
+    void testBuffersRecycledOnError() {
         NetworkBuffer buffer = getBuffer();
-        try (ChannelStateWriterImpl writer =
+        ChannelStateWriterImpl writer =
                 new ChannelStateWriterImpl(
-                        TASK_NAME, new ConcurrentHashMap<>(), failingWorker(), 5)) {
-            writer.open();
-            assertThatThrownBy(() -> callAddInputData(writer, buffer))
-                    .isInstanceOf(RuntimeException.class)
-                    .hasCauseInstanceOf(TestException.class);
-            assertThat(buffer.isRecycled()).isTrue();
-        }
+                        JOB_VERTEX_ID,
+                        TASK_NAME,
+                        SUBTASK_INDEX,
+                        new ConcurrentHashMap<>(),
+                        failingWorker(),
+                        5);
+        // A failing worker must make addInputData throw while the buffer still gets recycled.
+        assertThatThrownBy(() -> callAddInputData(writer, buffer))
+                .isInstanceOf(RuntimeException.class)
+                .hasCauseInstanceOf(TestException.class);
+        assertThat(buffer.isRecycled()).isTrue();
     }
 
     @Test
@@ -210,10 +218,19 @@ class ChannelStateWriterImplTest {
 
     @Test
     void testRethrowOnNextCall() {
-        SyncChannelStateWriteRequestExecutor worker = new SyncChannelStateWriteRequestExecutor();
+        SyncChannelStateWriteRequestExecutor worker =
+                new SyncChannelStateWriteRequestExecutor(JOB_ID);
         ChannelStateWriterImpl writer =
-                new ChannelStateWriterImpl(TASK_NAME, new ConcurrentHashMap<>(), worker, 5);
-        writer.open();
+                new ChannelStateWriterImpl(
+                        JOB_VERTEX_ID,
+                        TASK_NAME,
+                        SUBTASK_INDEX,
+                        new ConcurrentHashMap<>(),
+                        worker,
+                        5);
+        // Register the subtask with the directly supplied executor before using the writer.
+        worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+        // The error set on the executor must surface from the writer's next call (start).
         worker.setThrown(new TestException());
         assertThatThrownBy(() -> callStart(writer)).hasCauseInstanceOf(TestException.class);
     }
@@ -223,8 +238,13 @@ class ChannelStateWriterImplTest {
         int maxCheckpoints = 3;
         try (ChannelStateWriterImpl writer =
                 new ChannelStateWriterImpl(
-                        TASK_NAME, 0, getStreamFactoryFactory(), maxCheckpoints)) {
-            writer.open();
+                        JOB_VERTEX_ID,
+                        TASK_NAME,
+                        SUBTASK_INDEX,
+                        new JobManagerCheckpointStorage(),
+                        maxCheckpoints,
+                        new ChannelStateWriteRequestExecutorFactory(JOB_ID),
+                        5)) {
             for (int i = 0; i < maxCheckpoints; i++) {
                 writer.start(i, CheckpointOptions.forCheckpointWithDefaultLocation());
             }
@@ -237,15 +257,6 @@ class ChannelStateWriterImplTest {
         }
     }
 
-    @Test
-    void testStartNotOpened() throws IOException {
-        try (ChannelStateWriterImpl writer =
-                new ChannelStateWriterImpl(TASK_NAME, 0, getStreamFactoryFactory())) {
-            assertThatThrownBy(() -> callStart(writer))
-                    .hasCauseInstanceOf(IllegalStateException.class);
-        }
-    }
-
     @Test
     void testNoStartAfterClose() throws IOException {
         ChannelStateWriterImpl writer = openWriter();
@@ -274,8 +285,6 @@ class ChannelStateWriterImplTest {
 
     private ChannelStateWriteRequestExecutor failingWorker() {
         return new ChannelStateWriteRequestExecutor() {
-            @Override
-            public void close() {}
 
             @Override
             public void submit(ChannelStateWriteRequest e) {
@@ -289,6 +298,12 @@ class ChannelStateWriterImplTest {
 
             @Override
             public void start() throws IllegalStateException {}
+
+            @Override
+            public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
+
+            @Override
+            public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
         };
     }
 
@@ -306,21 +321,38 @@ class ChannelStateWriterImplTest {
                             ChannelStateWriter, SyncChannelStateWriteRequestExecutor, Exception>
                     testFn)
             throws Exception {
-        try (SyncChannelStateWriteRequestExecutor worker =
-                        new SyncChannelStateWriteRequestExecutor();
-                ChannelStateWriterImpl writer =
-                        new ChannelStateWriterImpl(
-                                TASK_NAME, new ConcurrentHashMap<>(), worker, 5)) {
-            writer.open();
+        SyncChannelStateWriteRequestExecutor worker =
+                new SyncChannelStateWriteRequestExecutor(JOB_ID);
+        try (ChannelStateWriterImpl writer =
+                new ChannelStateWriterImpl(
+                        JOB_VERTEX_ID,
+                        TASK_NAME,
+                        SUBTASK_INDEX,
+                        new ConcurrentHashMap<>(),
+                        worker,
+                        5)) {
+            // Register the subtask with the directly supplied executor before running testFn.
+            worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
             testFn.accept(writer, worker);
+        } finally {
+            // Release the subtask registered above, even if testFn fails.
+            worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
         }
     }
 
+    /**
+     * Creates a writer backed by {@link JobManagerCheckpointStorage} and a new {@link
+     * ChannelStateWriteRequestExecutorFactory} for {@link #JOB_ID}; no {@code open()} call is
+     * needed before using it.
+     */
     private ChannelStateWriterImpl openWriter() {
-        ChannelStateWriterImpl writer =
-                new ChannelStateWriterImpl(TASK_NAME, 0, getStreamFactoryFactory());
-        writer.open();
-        return writer;
+        return new ChannelStateWriterImpl(
+                JOB_VERTEX_ID,
+                TASK_NAME,
+                SUBTASK_INDEX,
+                new JobManagerCheckpointStorage(),
+                new ChannelStateWriteRequestExecutorFactory(JOB_ID),
+                5);
     }
 
     private void callStart(ChannelStateWriter writer) {
@@ -352,14 +377,12 @@ class SyncChannelStateWriteRequestExecutor implements ChannelStateWriteRequestEx
     private final Deque<ChannelStateWriteRequest> deque;
     private Exception thrown;
 
-    SyncChannelStateWriteRequestExecutor() {
+    /** Creates the executor; {@code jobID} is forwarded to the request dispatcher below. */
+    SyncChannelStateWriteRequestExecutor(JobID jobID) {
         deque = new ArrayDeque<>();
         requestProcessor =
                 new ChannelStateWriteRequestDispatcherImpl(
-                        "dummy task",
-                        0,
-                        getStreamFactoryFactory(),
-                        new ChannelStateSerializerImpl());
+                        new JobManagerCheckpointStorage(), jobID, new ChannelStateSerializerImpl());
     }
 
     @Override
@@ -382,7 +404,16 @@ class SyncChannelStateWriteRequestExecutor implements ChannelStateWriteRequestEx
     public void start() throws IllegalStateException {}
 
     @Override
-    public void close() {}
+    public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+        // Only queued here; the request is processed when processAllRequests() is called.
+        deque.add(ChannelStateWriteRequest.registerSubtask(jobVertexID, subtaskIndex));
+    }
+
+    @Override
+    public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+        // Only queued here; the request is processed when processAllRequests() is called.
+        deque.add(ChannelStateWriteRequest.releaseSubtask(jobVertexID, subtaskIndex));
+    }
 
     void processAllRequests() throws Exception {
         while (!deque.isEmpty()) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
index 12a87c14183..0d4fe2b14e5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
 import org.junit.jupiter.api.Test;
 
 import java.util.concurrent.CyclicBarrier;
@@ -62,6 +64,8 @@ class CheckpointInProgressRequestTest {
             AtomicInteger cancelCounter, CyclicBarrier cb) {
         return new CheckpointInProgressRequest(
                 "test",
+                new JobVertexID(),
+                0,
                 1L,
                 unused -> {},
                 unused -> {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 260e83331bb..ef35e7a9218 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -73,6 +74,8 @@ public class DummyEnvironment implements Environment {
     private final AccumulatorRegistry accumulatorRegistry;
     private UserCodeClassLoader userClassLoader;
     private final Configuration taskConfiguration = new Configuration();
+    private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory =
+            new ChannelStateWriteRequestExecutorFactory(jobId); // uses jobId's initial value
 
     public DummyEnvironment() {
         this("Test Job", 1, 0, 1);
@@ -268,4 +271,9 @@ public class DummyEnvironment implements Environment {
     public TaskOperatorEventGateway getOperatorCoordinatorEventGateway() {
         return new NoOpTaskOperatorEventGateway();
     }
+
+    @Override
+    public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+        return channelStateExecutorFactory; // same instance on every call
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index a5a53780b6c..f87b9d417f4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -135,6 +136,8 @@ public class MockEnvironment implements Environment, AutoCloseable {
 
     private CheckpointStorageAccess checkpointStorageAccess;
 
+    private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
     public static MockEnvironmentBuilder builder() {
         return new MockEnvironmentBuilder();
     }
@@ -157,7 +160,8 @@ public class MockEnvironment implements Environment, AutoCloseable {
             TaskMetricGroup taskMetricGroup,
             TaskManagerRuntimeInfo taskManagerRuntimeInfo,
             MemoryManager memManager,
-            ExternalResourceInfoProvider externalResourceInfoProvider) {
+            ExternalResourceInfoProvider externalResourceInfoProvider,
+            ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
 
         this.jobID = jobID;
         this.jobVertexID = jobVertexID;
@@ -195,6 +199,7 @@ public class MockEnvironment implements Environment, AutoCloseable {
         this.mainMailboxExecutor = new SyncMailboxExecutor();
 
         this.asyncOperationsThreadPool = Executors.newDirectExecutorService();
+        this.channelStateExecutorFactory = channelStateExecutorFactory;
     }
 
     public IteratorWrappingTestSingleInputGate<Record> addInput(
@@ -441,6 +446,11 @@ public class MockEnvironment implements Environment, AutoCloseable {
         return checkNotNull(checkpointStorageAccess);
     }
 
+    @Override
+    public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+        return channelStateExecutorFactory; // the instance given to the constructor
+    }
+
     public void setExpectedExternalFailureCause(Class<? extends Throwable> expectedThrowableClass) {
         this.expectedExternalFailureCause = Optional.of(expectedThrowableClass);
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java
index d3e723d05bf..ec61b9a6fa5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.operators.testutils;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
 import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
@@ -61,6 +62,8 @@ public class MockEnvironmentBuilder {
             buildMemoryManager(1024 * MemoryManager.DEFAULT_PAGE_SIZE);
     private ExternalResourceInfoProvider externalResourceInfoProvider =
             ExternalResourceInfoProvider.NO_EXTERNAL_RESOURCES;
+    /** Set explicitly, or {@code null} to derive one from {@code jobID} in {@link #build()}. */
+    private ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
 
     private MemoryManager buildMemoryManager(long memorySize) {
         return MemoryManagerBuilder.newBuilder().setMemorySize(memorySize).build();
@@ -159,6 +162,12 @@ public class MockEnvironmentBuilder {
         return this;
     }
 
+    public MockEnvironmentBuilder setChannelStateWriteRequestExecutorFactory(
+            ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
+        this.channelStateExecutorFactory = channelStateExecutorFactory;
+        return this;
+    }
+
     public MockEnvironment build() {
         if (ioManager == null) {
             ioManager = new IOManagerAsync();
@@ -181,6 +190,9 @@ public class MockEnvironmentBuilder {
                 taskMetricGroup,
                 taskManagerRuntimeInfo,
                 memoryManager,
-                externalResourceInfoProvider);
+                externalResourceInfoProvider,
+                channelStateExecutorFactory != null
+                        ? channelStateExecutorFactory
+                        : new ChannelStateWriteRequestExecutorFactory(jobID));
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
index 7f1518f529c..fe205a345ea 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
@@ -17,11 +17,13 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.flink.api.common.JobID;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
@@ -44,8 +46,9 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilde
 import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
 import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.state.memory.NonPersistentMetadataCheckpointStorageLocation;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.util.function.SupplierWithException;
 
 import org.junit.Test;
@@ -75,6 +78,9 @@ import static org.junit.Assert.assertNull;
 /** ChannelPersistenceITCase. */
 public class ChannelPersistenceITCase {
     private static final Random RANDOM = new Random(System.currentTimeMillis());
+    private static final JobID JOB_ID = new JobID();
+    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+    private static final int SUBTASK_INDEX = 0;
 
     @Test
     public void testUpstreamBlocksAfterRecoveringState() throws Exception {
@@ -250,8 +256,13 @@ public class ChannelPersistenceITCase {
         Map<ResultSubpartitionInfo, Buffer> rsBuffers = wrapWithBuffers(rsMap);
         Map<ResultSubpartitionInfo, Buffer> rsFutureBuffers = wrapWithBuffers(rsFutureMap);
         try (ChannelStateWriterImpl writer =
-                new ChannelStateWriterImpl("test", 0, getStreamFactoryFactory(maxStateSize))) {
-            writer.open();
+                new ChannelStateWriterImpl(
+                        JOB_VERTEX_ID,
+                        "test",
+                        SUBTASK_INDEX,
+                        new JobManagerCheckpointStorage(maxStateSize),
+                        new ChannelStateWriteRequestExecutorFactory(JOB_ID),
+                        5)) { // maxSubtasksPerChannelStateFile
             writer.start(
                     checkpointId,
                     new CheckpointOptions(
@@ -283,30 +294,6 @@ public class ChannelPersistenceITCase {
         }
     }
 
-    public static CheckpointStorageWorkerView getStreamFactoryFactory() {
-        return getStreamFactoryFactory(42);
-    }
-
-    public static CheckpointStorageWorkerView getStreamFactoryFactory(int maxStateSize) {
-        return new CheckpointStorageWorkerView() {
-            @Override
-            public CheckpointStreamFactory resolveCheckpointStorageLocation(
-                    long checkpointId, CheckpointStorageLocationReference reference) {
-                return new NonPersistentMetadataCheckpointStorageLocation(maxStateSize);
-            }
-
-            @Override
-            public CheckpointStateOutputStream createTaskOwnedStateStream() {
-                throw new UnsupportedOperationException();
-            }
-
-            @Override
-            public CheckpointStateToolset createTaskOwnedCheckpointStateToolset() {
-                throw new UnsupportedOperationException();
-            }
-        };
-    }
-
     private TaskStateSnapshot toTaskStateSnapshot(ChannelStateWriteResult t) throws Exception {
         return new TaskStateSnapshot(
                 singletonMap(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManagerTest.java
new file mode 100644
index 00000000000..32b2e28528e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManagerTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link TaskExecutorChannelStateExecutorFactoryManager}. */
+class TaskExecutorChannelStateExecutorFactoryManagerTest {
+
+    /** Factories are shared within a job and distinct across jobs. */
+    @Test
+    void testReuseFactory() {
+        TaskExecutorChannelStateExecutorFactoryManager manager =
+                new TaskExecutorChannelStateExecutorFactoryManager();
+        try {
+            JobID jobID = new JobID();
+            ChannelStateWriteRequestExecutorFactory factory =
+                    manager.getOrCreateExecutorFactory(jobID);
+            assertThat(manager.getOrCreateExecutorFactory(jobID))
+                    .as("Same job should share the executor factory.")
+                    .isSameAs(factory);
+
+            assertThat(manager.getOrCreateExecutorFactory(new JobID()))
+                    .as("Different jobs cannot share executor factory.")
+                    .isNotSameAs(factory);
+        } finally {
+            // Shut down even when an assertion fails, so the manager is always released.
+            manager.shutdown();
+        }
+    }
+
+    /** Releasing a job's resources removes its factory from the manager. */
+    @Test
+    void testReleaseForJob() {
+        TaskExecutorChannelStateExecutorFactoryManager manager =
+                new TaskExecutorChannelStateExecutorFactoryManager();
+        try {
+            JobID jobID = new JobID();
+            assertThat(manager.getFactoryByJobId(jobID)).isNull();
+            manager.getOrCreateExecutorFactory(jobID);
+            assertThat(manager.getFactoryByJobId(jobID)).isNotNull();
+
+            manager.releaseResourcesForJob(jobID);
+            assertThat(manager.getFactoryByJobId(jobID)).isNull();
+        } finally {
+            // Shut down even when an assertion fails, so the manager is always released.
+            manager.shutdown();
+        }
+    }
+
+    /** After shutdown the manager rejects factory requests, for known and new jobs alike. */
+    @Test
+    void testShutdown() {
+        TaskExecutorChannelStateExecutorFactoryManager manager =
+                new TaskExecutorChannelStateExecutorFactoryManager();
+
+        JobID jobID = new JobID();
+        manager.getOrCreateExecutorFactory(jobID);
+        manager.shutdown();
+
+        assertThatThrownBy(() -> manager.getOrCreateExecutorFactory(jobID))
+                .isInstanceOf(IllegalStateException.class);
+        assertThatThrownBy(() -> manager.getOrCreateExecutorFactory(new JobID()))
+                .isInstanceOf(IllegalStateException.class);
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java
index cfbf6548dd5..93671d14d4d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.memory.SharedResources;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.registration.RetryingRegistrationConfiguration;
 import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.state.TaskExecutorChannelStateExecutorFactoryManager;
 import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
 import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
 import org.apache.flink.runtime.taskexecutor.slot.NoOpSlotAllocationSnapshotPersistenceService;
@@ -58,6 +59,7 @@ public class TaskManagerServicesBuilder {
     private JobLeaderService jobLeaderService;
     private TaskExecutorLocalStateStoresManager taskStateManager;
     private TaskExecutorStateChangelogStoragesManager taskChangelogStoragesManager;
+    private TaskExecutorChannelStateExecutorFactoryManager taskChannelStateExecutorFactoryManager;
     private TaskEventDispatcher taskEventDispatcher;
     private LibraryCacheManager libraryCacheManager;
     private SharedResources sharedResources;
@@ -82,6 +84,8 @@ public class TaskManagerServicesBuilder {
                         RetryingRegistrationConfiguration.defaultConfiguration());
         taskStateManager = mock(TaskExecutorLocalStateStoresManager.class);
         taskChangelogStoragesManager = mock(TaskExecutorStateChangelogStoragesManager.class);
+        taskChannelStateExecutorFactoryManager =
+                new TaskExecutorChannelStateExecutorFactoryManager(); // real instance, not a mock
         libraryCacheManager = TestingLibraryCacheManager.newBuilder().build();
         managedMemorySize = MemoryManager.MIN_PAGE_SIZE;
         this.slotAllocationSnapshotPersistenceService =
@@ -144,6 +148,12 @@ public class TaskManagerServicesBuilder {
         return this;
     }
 
+    public TaskManagerServicesBuilder setTaskChannelStateExecutorFactoryManager(
+            TaskExecutorChannelStateExecutorFactoryManager taskChannelStateExecutorFactoryManager) {
+        this.taskChannelStateExecutorFactoryManager = taskChannelStateExecutorFactoryManager;
+        return this;
+    }
+
     public TaskManagerServicesBuilder setLibraryCacheManager(
             LibraryCacheManager libraryCacheManager) {
         this.libraryCacheManager = libraryCacheManager;
@@ -174,6 +184,7 @@ public class TaskManagerServicesBuilder {
                 jobLeaderService,
                 taskStateManager,
                 taskChangelogStoragesManager,
+                taskChannelStateExecutorFactoryManager,
                 taskEventDispatcher,
                 Executors.newSingleThreadScheduledExecutor(),
                 libraryCacheManager,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index 1eb55ab0c04..cfefb0459b1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -218,7 +219,8 @@ public class TaskAsyncCallTest extends TestLogger {
                 new TestingTaskManagerRuntimeInfo(),
                 taskMetricGroup,
                 partitionProducerStateChecker,
-                executor);
+                executor, // next: fresh, unshared channel state factory
+                new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
     }
 
     /** Invokable for testing checkpoints. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java
index a2559c5e650..127600a3cbc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.blob.PermanentBlobKey;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -225,7 +226,8 @@ public final class TestTaskBuilder {
                 new TestingTaskManagerRuntimeInfo(taskManagerConfig),
                 taskMetricGroup,
                 partitionProducerStateChecker,
-                executor);
+                executor, // next: fresh, unshared channel state factory
+                new ChannelStateWriteRequestExecutorFactory(jobId));
     }
 
     public static void setTaskState(Task task, ExecutionState state) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
index 18c026105e5..dadac718a67 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.io.InputSplit;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.blob.VoidPermanentBlobService;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -254,7 +255,9 @@ public class JvmExitOnFatalErrorTest extends TestLogger {
                                 tmInfo,
                                 UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
                                 new NoOpPartitionProducerStateChecker(),
-                                executor);
+                                executor, // next: fresh, unshared channel state factory
+                                new ChannelStateWriteRequestExecutorFactory(
+                                        jobInformation.getJobId()));
 
                 System.err.println("starting task thread");
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java
index 7e59e0874c9..b12ce559c91 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java
@@ -604,6 +604,37 @@ public class CheckpointConfig implements java.io.Serializable {
                 ExecutionCheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT, alignedCheckpointTimeout);
     }
 
+    /**
+     * @return the number of subtasks to share the same channel state file, as configured via {@link
+     *     #setMaxSubtasksPerChannelStateFile(int)} or {@link
+     *     ExecutionCheckpointingOptions#UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE}.
+     */
+    @PublicEvolving
+    public int getMaxSubtasksPerChannelStateFile() {
+        return configuration.get(
+                ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE);
+    }
+
+    /**
+     * Sets the maximum number of subtasks that share the same channel state file. If {@link
+     * ExecutionCheckpointingOptions#UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE} is set to
+     * {@code 1}, each subtask will create a new channel state file.
+     *
+     * @param maxSubtasksPerChannelStateFile the maximum number of subtasks per file; must be at
+     *     least {@code 1}
+     * @throws IllegalArgumentException if {@code maxSubtasksPerChannelStateFile} is less than 1
+     */
+    @PublicEvolving
+    public void setMaxSubtasksPerChannelStateFile(int maxSubtasksPerChannelStateFile) {
+        if (maxSubtasksPerChannelStateFile < 1) {
+            throw new IllegalArgumentException(
+                    "The maximum number of subtasks per channel state file must be at least one.");
+        }
+        configuration.set(
+                ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE,
+                maxSubtasksPerChannelStateFile);
+    }
+
     /**
      * Returns whether approximate local recovery is enabled.
      *
@@ -841,6 +872,10 @@ public class CheckpointConfig implements java.io.Serializable {
         configuration
                 .getOptional(ExecutionCheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT)
                 .ifPresent(this::setAlignedCheckpointTimeout);
+        configuration
+                .getOptional(
+                        ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE)
+                .ifPresent(this::setMaxSubtasksPerChannelStateFile);
         configuration
                 .getOptional(ExecutionCheckpointingOptions.FORCE_UNALIGNED)
                 .ifPresent(this::setForceUnalignedCheckpoints);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java
index 02ebbd44e82..f13236ae049 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java
@@ -280,4 +280,13 @@ public class ExecutionCheckpointingOptions {
                     .booleanType()
                     .defaultValue(false)
                     .withDescription("Flag to enable approximate local recovery.");
+
+    public static final ConfigOption<Integer> UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE =
+            key("execution.checkpointing.unaligned.max-subtasks-per-channel-state-file")
+                    .intType()
+                    .defaultValue(5) // at most 5 subtasks per file unless configured
+                    .withDescription(
+                            "Defines the maximum number of subtasks that share the same channel state file. "
+                                    + "It can reduce the number of small files when enable unaligned checkpoint. "
+                                    + "Each subtask will create a new channel state file when this is configured to 1.");
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
index 476c165c8c0..267289c181f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
@@ -530,6 +530,17 @@ public class StreamConfig implements Serializable {
                 ExecutionCheckpointingOptions.MAX_CONCURRENT_CHECKPOINTS.defaultValue());
     }
 
+    public int getMaxSubtasksPerChannelStateFile() { // falls back to the option's default
+        return config.get(
+                ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE);
+    }
+
+    public void setMaxSubtasksPerChannelStateFile(int maxSubtasksPerChannelStateFile) {
+        config.set(
+                ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE,
+                maxSubtasksPerChannelStateFile);
+    }
+
     /**
      * Sets the job vertex level non-chained outputs. The given output list must have the same order
      * with {@link JobVertex#getProducedDataSets()}.
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 5f96de65898..5161a2586bf 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -1009,6 +1009,7 @@ public class StreamingJobGraphGenerator {
         config.setCheckpointMode(getCheckpointingMode(checkpointCfg));
         config.setUnalignedCheckpointsEnabled(checkpointCfg.isUnalignedCheckpointsEnabled());
         config.setAlignedCheckpointTimeout(checkpointCfg.getAlignedCheckpointTimeout());
+        config.setMaxSubtasksPerChannelStateFile(checkpointCfg.getMaxSubtasksPerChannelStateFile());
         config.setMaxConcurrentCheckpoints(checkpointCfg.getMaxConcurrentCheckpoints());
 
         for (int i = 0; i < vertex.getStatePartitioners().length; i++) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 40ad6a1fceb..ee10e7929e3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -456,6 +456,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
             this.subtaskCheckpointCoordinator =
                     new SubtaskCheckpointCoordinatorImpl(
+                            checkpointStorage, // used to open the channel state writer
                             checkpointStorageAccess,
                             getName(),
                             actionExecutor,
@@ -471,7 +472,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                             this::prepareInputSnapshot,
                             configuration.getMaxConcurrentCheckpoints(),
                             BarrierAlignmentUtil.createRegisterTimerCallback(
-                                    mainMailboxExecutor, systemTimerService));
+                                    mainMailboxExecutor, systemTimerService),
+                            configuration.getMaxSubtasksPerChannelStateFile());
             resourceCloser.registerCloseable(subtaskCheckpointCoordinator::close);
 
             // Register to stop all timers and threads. Should be closed first.
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
index 82b6e7f8093..c12541b7f6d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.CheckpointStateOutputStream;
 import org.apache.flink.runtime.state.CheckpointStateToolset;
+import org.apache.flink.runtime.state.CheckpointStorage;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -125,7 +126,8 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
     private long alignmentCheckpointId;
 
     SubtaskCheckpointCoordinatorImpl(
-            CheckpointStorageWorkerView checkpointStorage,
+            CheckpointStorage checkpointStorage, // for openChannelStateWriter below
+            CheckpointStorageWorkerView checkpointStorageView, // passed to this(...) below
             String taskName,
             StreamTaskActionExecutor actionExecutor,
             ExecutorService asyncOperationsThreadPool,
@@ -137,9 +139,10 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
                             ChannelStateWriter, Long, CompletableFuture<Void>, CheckpointException>
                     prepareInputSnapshot,
             int maxRecordAbortedCheckpoints,
-            DelayableTimer registerTimer) {
+            DelayableTimer registerTimer,
+            int maxSubtasksPerChannelStateFile) { // consumed by openChannelStateWriter below
         this(
-                checkpointStorage,
+                checkpointStorageView,
                 taskName,
                 actionExecutor,
                 asyncOperationsThreadPool,
@@ -148,7 +151,8 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
                 prepareInputSnapshot,
                 maxRecordAbortedCheckpoints,
                 unalignedCheckpointEnabled
-                        ? openChannelStateWriter(taskName, checkpointStorage, env)
+                        ? openChannelStateWriter(
+                                taskName, checkpointStorage, env, maxSubtasksPerChannelStateFile)
                         : ChannelStateWriter.NO_OP,
                 enableCheckpointAfterTasksFinished,
                 registerTimer);
@@ -191,12 +195,17 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
     }
 
     private static ChannelStateWriter openChannelStateWriter(
-            String taskName, CheckpointStorageWorkerView checkpointStorage, Environment env) {
-        ChannelStateWriterImpl writer =
-                new ChannelStateWriterImpl(
-                        taskName, env.getTaskInfo().getIndexOfThisSubtask(), checkpointStorage);
-        writer.open();
-        return writer;
+            String taskName,
+            CheckpointStorage checkpointStorage,
+            Environment env,
+            int maxSubtasksPerChannelStateFile) {
+        return new ChannelStateWriterImpl(
+                env.getJobVertexId(),
+                taskName,
+                env.getTaskInfo().getIndexOfThisSubtask(),
+                checkpointStorage,
+                env.getChannelStateExecutorFactory(),
+                maxSubtasksPerChannelStateFile);
     }
 
     @Override
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 5ddaaa0229e..26361ae31ed 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -293,7 +294,8 @@ public class InterruptSensitiveRestoreTest {
                 new TestingTaskManagerRuntimeInfo(),
                 UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
                 mock(PartitionProducerStateChecker.class),
-                mock(Executor.class));
+                mock(Executor.class),
+                new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
     }
 
     // ------------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java
index 2c1af6fb1ed..38b3e7dce02 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java
@@ -22,8 +22,8 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
-import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
-import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAccess;
+import org.apache.flink.runtime.state.CheckpointStorage;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.runtime.taskmanager.AsyncExceptionHandler;
 import org.apache.flink.util.concurrent.Executors;
 import org.apache.flink.util.concurrent.FutureUtils;
@@ -38,7 +38,7 @@ import static org.apache.flink.streaming.runtime.tasks.StreamTaskActionExecutor.
 /** A mock builder to build {@link SubtaskCheckpointCoordinator}. */
 public class MockSubtaskCheckpointCoordinatorBuilder {
     private String taskName = "mock-task";
-    private CheckpointStorageWorkerView checkpointStorage;
+    private CheckpointStorage checkpointStorage;
     private Environment environment;
     private AsyncExceptionHandler asyncExceptionHandler;
     private StreamTaskActionExecutor actionExecutor = IMMEDIATE;
@@ -47,6 +47,7 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
                     ChannelStateWriter, Long, CompletableFuture<Void>, CheckpointException>
             prepareInputSnapshot = (channelStateWriter, aLong) -> FutureUtils.completedVoidFuture();
     private boolean unalignedCheckpointEnabled;
+    private int maxSubtasksPerChannelStateFile = 5;
     private int maxRecordAbortedCheckpoints = 10;
     private boolean enableCheckpointAfterTasksFinished = true;
 
@@ -86,14 +87,18 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
         return this;
     }
 
+    public MockSubtaskCheckpointCoordinatorBuilder setMaxSubtasksPerChannelStateFile(
+            int maxSubtasksPerChannelStateFile) {
+        this.maxSubtasksPerChannelStateFile = maxSubtasksPerChannelStateFile;
+        return this;
+    }
+
     SubtaskCheckpointCoordinator build() throws IOException {
         if (environment == null) {
             this.environment = MockEnvironment.builder().build();
         }
         if (checkpointStorage == null) {
-            this.checkpointStorage =
-                    new MemoryBackendCheckpointStorageAccess(
-                            environment.getJobID(), null, null, 1024);
+            this.checkpointStorage = new JobManagerCheckpointStorage();
         }
         if (asyncExceptionHandler == null) {
             this.asyncExceptionHandler = new NonHandleAsyncException();
@@ -101,6 +106,7 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
 
         return new SubtaskCheckpointCoordinatorImpl(
                 checkpointStorage,
+                checkpointStorage.createCheckpointStorage(environment.getJobID()),
                 taskName,
                 actionExecutor,
                 executorService,
@@ -110,7 +116,8 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
                 enableCheckpointAfterTasksFinished,
                 prepareInputSnapshot,
                 maxRecordAbortedCheckpoints,
-                (callable, duration) -> () -> {});
+                (callable, duration) -> () -> {},
+                maxSubtasksPerChannelStateFile);
     }
 
     private static class NonHandleAsyncException implements AsyncExceptionHandler {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 9d37bc16cf8..563ee0ef10d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -120,6 +121,8 @@ public class StreamMockEnvironment implements Environment {
 
     private final boolean collectNetworkEvents;
 
+    private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
     @Nullable private Consumer<Throwable> externalExceptionHandler;
 
     private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class);
@@ -196,6 +199,7 @@ public class StreamMockEnvironment implements Environment {
                 registry.createTaskRegistry(
                         jobID, executionAttemptID.getExecutionVertexId().getJobVertexId());
         this.collectNetworkEvents = collectNetworkEvents;
+        this.channelStateExecutorFactory = new ChannelStateWriteRequestExecutorFactory(jobID);
     }
 
     public StreamMockEnvironment(
@@ -415,4 +419,9 @@ public class StreamMockEnvironment implements Environment {
     public void setCheckpointResponder(CheckpointResponder checkpointResponder) {
         this.checkpointResponder = checkpointResponder;
     }
+
+    @Override
+    public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+        return channelStateExecutorFactory;
+    }
 }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java
index 578cd953509..bf330729830 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.security.FlinkSecurityManager;
 import org.apache.flink.core.security.UserSystemExitException;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -205,7 +206,8 @@ public class StreamTaskSystemExitTest extends TestLogger {
                 taskManagerRuntimeInfo,
                 UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
                 mock(PartitionProducerStateChecker.class),
-                Executors.directExecutor());
+                Executors.directExecutor(),
+                new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
     }
 
     /** StreamTask emulating system exit behavior from different callback functions. */
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
index a81888a5ffc..5b08e50fe35 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -189,7 +190,8 @@ public class StreamTaskTerminationTest extends TestLogger {
                         taskManagerRuntimeInfo,
                         UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
                         mock(PartitionProducerStateChecker.class),
-                        Executors.directExecutor());
+                        Executors.directExecutor(),
+                        new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
 
         CompletableFuture<Void> taskRun =
                 CompletableFuture.runAsync(() -> task.run(), EXECUTOR_RESOURCE.getExecutor());
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
index 684568b6ce8..47354b1b6ab 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
@@ -51,6 +51,7 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.TestCheckpointStorageWorkerView;
 import org.apache.flink.runtime.state.TestTaskStateManager;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -84,7 +85,6 @@ import java.util.concurrent.atomic.AtomicReference;
 import java.util.function.Supplier;
 
 import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.shaded.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
 import static org.junit.Assert.assertEquals;
@@ -577,8 +577,15 @@ public class SubtaskCheckpointCoordinatorTest {
     @Test
     public void testChannelStateWriteResultLeakAndNotFailAfterCheckpointAborted() throws Exception {
         String taskName = "test";
+        DummyEnvironment env = new DummyEnvironment();
         ChannelStateWriterImpl writer =
-                new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
+                new ChannelStateWriterImpl(
+                        env.getJobVertexId(),
+                        taskName,
+                        0,
+                        new JobManagerCheckpointStorage(),
+                        env.getChannelStateExecutorFactory(),
+                        5);
         try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
                 SubtaskCheckpointCoordinator coordinator =
                         new SubtaskCheckpointCoordinatorImpl(
@@ -586,14 +593,13 @@ public class SubtaskCheckpointCoordinatorTest {
                                 taskName,
                                 StreamTaskActionExecutor.IMMEDIATE,
                                 newDirectExecutorService(),
-                                new DummyEnvironment(),
+                                env,
                                 (unused1, unused2) -> {},
                                 (unused1, unused2) -> CompletableFuture.completedFuture(null),
                                 1,
                                 writer,
                                 true,
                                 (callable, duration) -> () -> {})) {
-            writer.open();
             final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
             int checkpointId = 1;
             // Abort checkpoint 1
@@ -629,23 +635,29 @@ public class SubtaskCheckpointCoordinatorTest {
         CheckpointOptions unalignedOptions =
                 CheckpointOptions.unaligned(
                         CHECKPOINT, CheckpointStorageLocationReference.getDefault());
+        DummyEnvironment env = new DummyEnvironment();
+        ChannelStateWriterImpl writer =
+                new ChannelStateWriterImpl(
+                        env.getJobVertexId(),
+                        taskName,
+                        0,
+                        new JobManagerCheckpointStorage(),
+                        env.getChannelStateExecutorFactory(),
+                        5);
         try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
-                ChannelStateWriterImpl writer =
-                        new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
                 SubtaskCheckpointCoordinator coordinator =
                         new SubtaskCheckpointCoordinatorImpl(
                                 new TestCheckpointStorageWorkerView(100),
                                 taskName,
                                 StreamTaskActionExecutor.IMMEDIATE,
                                 newDirectExecutorService(),
-                                new DummyEnvironment(),
+                                env,
                                 (unused1, unused2) -> {},
                                 (unused1, unused2) -> CompletableFuture.completedFuture(null),
                                 1,
                                 writer,
                                 true,
                                 (callable, duration) -> () -> {})) {
-            writer.open();
             final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
             int checkpoint42 = 42;
             int checkpoint43 = 43;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java
index 0e18799404e..53ed047170c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.SavepointType;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -275,7 +276,8 @@ public class SynchronousCheckpointITCase {
                 new TestingTaskManagerRuntimeInfo(),
                 taskMetricGroup,
                 partitionProducerStateChecker,
-                executor);
+                executor,
+                new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
     }
 
     private static class TaskCleaner implements AutoCloseable {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
index f936b1f65db..bf52ca116bd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -244,7 +245,8 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
                 new TestingTaskManagerRuntimeInfo(),
                 UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
                 mock(PartitionProducerStateChecker.class),
-                Executors.directExecutor());
+                Executors.directExecutor(),
+                new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
     }
 
     // ------------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
index 283dd913ccb..54ece459aff 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
@@ -78,6 +78,7 @@ import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingO
 import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.CHECKPOINTING_MODE;
 import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.ENABLE_UNALIGNED;
 import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.EXTERNALIZED_CHECKPOINT;
+import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /** Tests caching of changelog segments downloaded during recovery. */
@@ -181,6 +182,9 @@ public class ChangelogRecoveryCachingITCase extends TestLogger {
 
         conf.set(ENABLE_UNALIGNED, true); // speedup
         conf.set(ALIGNED_CHECKPOINT_TIMEOUT, Duration.ZERO); // prevent randomization
+        conf.set(
+                UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE,
+                1); // prevent the file from being opened multiple times
         conf.set(BUFFER_DEBLOAT_ENABLED, false); // prevent randomization
         conf.set(RESTART_STRATEGY, "none"); // not expecting any failures