You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2023/01/23 15:47:27 UTC
[flink] 03/03: [FLINK-26803][checkpoint] Merge the channel state files
This is an automated email from the ASF dual-hosted git repository.
pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 8be94e6663d8ac6e3d74bf4cd5f540cc96c8289e
Author: 1996fanrui <19...@gmail.com>
AuthorDate: Tue Nov 8 12:14:15 2022 +0800
[FLINK-26803][checkpoint] Merge the channel state files
fixup! Address comments
fixup! fixup! Address comments
1. Add some job docs
2. Using the lock to refactor the ChannelStateWriteRequestExecutorImpl
fixup! fixup! fixup! Address comments
---
.../execution_checkpointing_configuration.html | 6 +
.../state/api/runtime/SavepointEnvironment.java | 9 +
.../checkpoint/CheckpointFailureManager.java | 1 +
.../checkpoint/CheckpointFailureReason.java | 4 +
.../channel/ChannelStateCheckpointWriter.java | 416 ++++++++++++---------
.../channel/ChannelStatePendingResult.java | 192 ++++++++++
.../channel/ChannelStateWriteRequest.java | 229 +++++++++---
.../ChannelStateWriteRequestDispatcherImpl.java | 176 ++++++---
.../channel/ChannelStateWriteRequestExecutor.java | 14 +-
.../ChannelStateWriteRequestExecutorFactory.java | 67 ++++
.../ChannelStateWriteRequestExecutorImpl.java | 311 +++++++++++++--
.../checkpoint/channel/ChannelStateWriterImpl.java | 85 +++--
.../flink/runtime/execution/Environment.java | 3 +
.../io/network/logger/NetworkActionsLogger.java | 5 +-
...ExecutorChannelStateExecutorFactoryManager.java | 106 ++++++
.../flink/runtime/taskexecutor/TaskExecutor.java | 14 +-
.../runtime/taskexecutor/TaskManagerServices.java | 12 +
.../runtime/taskmanager/RuntimeEnvironment.java | 12 +-
.../org/apache/flink/runtime/taskmanager/Task.java | 11 +-
.../channel/ChannelStateCheckpointWriterTest.java | 257 +++++++++++--
...ChannelStateWriteRequestDispatcherImplTest.java | 176 +++++++--
.../ChannelStateWriteRequestDispatcherTest.java | 35 +-
...hannelStateWriteRequestExecutorFactoryTest.java | 72 ++++
.../ChannelStateWriteRequestExecutorImplTest.java | 326 ++++++++++++++--
.../channel/ChannelStateWriteResultUtil.java | 75 ++++
.../channel/ChannelStateWriterImplTest.java | 113 +++---
.../channel/CheckpointInProgressRequestTest.java | 4 +
.../operators/testutils/DummyEnvironment.java | 8 +
.../operators/testutils/MockEnvironment.java | 12 +-
.../testutils/MockEnvironmentBuilder.java | 12 +-
.../runtime/state/ChannelPersistenceITCase.java | 41 +-
...utorChannelStateExecutorFactoryManagerTest.java | 78 ++++
.../taskexecutor/TaskManagerServicesBuilder.java | 11 +
.../runtime/taskmanager/TaskAsyncCallTest.java | 4 +-
.../flink/runtime/taskmanager/TestTaskBuilder.java | 4 +-
.../runtime/util/JvmExitOnFatalErrorTest.java | 5 +-
.../api/environment/CheckpointConfig.java | 27 ++
.../environment/ExecutionCheckpointingOptions.java | 9 +
.../flink/streaming/api/graph/StreamConfig.java | 11 +
.../api/graph/StreamingJobGraphGenerator.java | 1 +
.../flink/streaming/runtime/tasks/StreamTask.java | 4 +-
.../tasks/SubtaskCheckpointCoordinatorImpl.java | 29 +-
.../tasks/InterruptSensitiveRestoreTest.java | 4 +-
.../MockSubtaskCheckpointCoordinatorBuilder.java | 21 +-
.../runtime/tasks/StreamMockEnvironment.java | 9 +
.../runtime/tasks/StreamTaskSystemExitTest.java | 4 +-
.../runtime/tasks/StreamTaskTerminationTest.java | 4 +-
.../tasks/SubtaskCheckpointCoordinatorTest.java | 28 +-
.../runtime/tasks/SynchronousCheckpointITCase.java | 4 +-
.../tasks/TaskCheckpointingBehaviourTest.java | 4 +-
.../test/state/ChangelogRecoveryCachingITCase.java | 4 +
51 files changed, 2497 insertions(+), 572 deletions(-)
diff --git a/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html b/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html
index 1889bd7e9f1..fcecef43791 100644
--- a/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html
+++ b/docs/layouts/shortcodes/generated/execution_checkpointing_configuration.html
@@ -14,6 +14,12 @@
<td>Duration</td>
<td>Only relevant if <code class="highlighter-rouge">execution.checkpointing.unaligned.enabled</code> is enabled.<br /><br />If timeout is 0, checkpoints will always start unaligned.<br /><br />If timeout has a positive value, checkpoints will start aligned. If during checkpointing, checkpoint start delay exceeds this timeout, alignment will timeout and checkpoint barrier will start working as unaligned checkpoint.</td>
</tr>
+ <tr>
+ <td><h5>execution.checkpointing.unaligned.max-subtasks-per-channel-state-file</h5></td>
+ <td style="word-wrap: break-word;">5</td>
+ <td>Integer</td>
+ <td>Defines the maximum number of subtasks that share the same channel state file. It can reduce the number of small files when enable unaligned checkpoint. Each subtask will create a new channel state file when this is configured to 1.</td>
+ </tr>
<tr>
<td><h5>execution.checkpointing.checkpoints-after-tasks-finish.enabled</h5></td>
<td style="word-wrap: break-word;">true</td>
diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
index ee230c39d4b..73c17de4104 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.PrioritizedOperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionGraphID;
@@ -107,6 +108,8 @@ public class SavepointEnvironment implements Environment {
private final UserCodeClassLoader userCodeClassLoader;
+ private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
private SavepointEnvironment(
RuntimeContext ctx,
Configuration configuration,
@@ -133,6 +136,7 @@ public class SavepointEnvironment implements Environment {
this.accumulatorRegistry = new AccumulatorRegistry(jobID, attemptID);
this.userCodeClassLoader = UserCodeClassLoaderRuntimeContextAdapter.from(ctx);
+ this.channelStateExecutorFactory = new ChannelStateWriteRequestExecutorFactory(jobID);
}
@Override
@@ -306,6 +310,11 @@ public class SavepointEnvironment implements Environment {
throw new UnsupportedOperationException(ERROR_MSG);
}
+ @Override
+ public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+ return channelStateExecutorFactory;
+ }
+
/** {@link SavepointEnvironment} builder. */
public static class Builder {
private RuntimeContext ctx;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java
index 4b33a0aab1c..fabff241649 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureManager.java
@@ -228,6 +228,7 @@ public class CheckpointFailureManager {
case CHECKPOINT_SUBSUMED:
case CHECKPOINT_COORDINATOR_SUSPEND:
case CHECKPOINT_COORDINATOR_SHUTDOWN:
+ case CHANNEL_STATE_SHARED_STREAM_EXCEPTION:
case JOB_FAILOVER_REGION:
// for compatibility purposes with user job behavior
case CHECKPOINT_DECLINED_TASK_NOT_READY:
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java
index e85b0130422..6f56bd67b6e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/CheckpointFailureReason.java
@@ -36,6 +36,10 @@ public enum CheckpointFailureReason {
CHECKPOINT_ASYNC_EXCEPTION(false, "Asynchronous task checkpoint failed."),
+ CHANNEL_STATE_SHARED_STREAM_EXCEPTION(
+ false,
+ "The checkpoint was aborted due to exception of other subtasks sharing the ChannelState file."),
+
CHECKPOINT_EXPIRED(false, "Checkpoint expired before completing."),
CHECKPOINT_SUBSUMED(false, "Checkpoint has been subsumed."),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
index c858896c5d5..4173bb7140e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriter.java
@@ -18,76 +18,72 @@
package org.apache.flink.runtime.checkpoint.channel;
import org.apache.flink.annotation.VisibleForTesting;
-import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.logger.NetworkActionsLogger;
-import org.apache.flink.runtime.state.AbstractChannelStateHandle;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
import org.apache.flink.runtime.state.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
-import org.apache.flink.runtime.state.InputChannelStateHandle;
-import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
import org.apache.flink.util.Preconditions;
import org.apache.flink.util.function.RunnableWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.annotation.Nonnull;
import javax.annotation.concurrent.NotThreadSafe;
import java.io.DataOutputStream;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Collection;
import java.util.HashMap;
-import java.util.List;
+import java.util.HashSet;
import java.util.Map;
-import java.util.Optional;
-import java.util.concurrent.CompletableFuture;
+import java.util.Objects;
+import java.util.Set;
-import static java.util.Collections.emptyList;
-import static java.util.Collections.singletonList;
-import static java.util.UUID.randomUUID;
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
import static org.apache.flink.util.ExceptionUtils.findThrowable;
import static org.apache.flink.util.ExceptionUtils.rethrow;
+import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
-/** Writes channel state for a specific checkpoint-subtask-attempt triple. */
+/** Writes channel state for multiple subtasks of the same checkpoint. */
@NotThreadSafe
class ChannelStateCheckpointWriter {
private static final Logger LOG = LoggerFactory.getLogger(ChannelStateCheckpointWriter.class);
private final DataOutputStream dataStream;
private final CheckpointStateOutputStream checkpointStream;
- private final ChannelStateWriteResult result;
- private final Map<InputChannelInfo, StateContentMetaInfo> inputChannelOffsets = new HashMap<>();
- private final Map<ResultSubpartitionInfo, StateContentMetaInfo> resultSubpartitionOffsets =
- new HashMap<>();
+
+ /**
+ * The failure that caused this writer to be aborted, or null if no failure was recorded. If
+ * it's not null, the writer is considered done and the checkpoint will fail.
+ */
+ private Throwable throwable;
+
private final ChannelStateSerializer serializer;
private final long checkpointId;
- private boolean allInputsReceived = false;
- private boolean allOutputsReceived = false;
private final RunnableWithException onComplete;
- private final int subtaskIndex;
- private final String taskName;
+
+ // Subtasks that have not yet registered their write result.
+ private final Set<SubtaskID> subtasksToRegister;
+
+ private final Map<SubtaskID, ChannelStatePendingResult> pendingResults = new HashMap<>();
ChannelStateCheckpointWriter(
- String taskName,
- int subtaskIndex,
- CheckpointStartRequest startCheckpointItem,
+ Set<SubtaskID> subtasks,
+ long checkpointId,
CheckpointStreamFactory streamFactory,
ChannelStateSerializer serializer,
RunnableWithException onComplete)
throws Exception {
this(
- taskName,
- subtaskIndex,
- startCheckpointItem.getCheckpointId(),
- startCheckpointItem.getTargetResult(),
+ subtasks,
+ checkpointId,
streamFactory.createCheckpointStateOutputStream(EXCLUSIVE),
serializer,
onComplete);
@@ -95,38 +91,25 @@ class ChannelStateCheckpointWriter {
@VisibleForTesting
ChannelStateCheckpointWriter(
- String taskName,
- int subtaskIndex,
+ Set<SubtaskID> subtasks,
long checkpointId,
- ChannelStateWriteResult result,
CheckpointStateOutputStream stream,
ChannelStateSerializer serializer,
RunnableWithException onComplete) {
- this(
- taskName,
- subtaskIndex,
- checkpointId,
- result,
- serializer,
- onComplete,
- stream,
- new DataOutputStream(stream));
+ this(subtasks, checkpointId, serializer, onComplete, stream, new DataOutputStream(stream));
}
@VisibleForTesting
ChannelStateCheckpointWriter(
- String taskName,
- int subtaskIndex,
+ Set<SubtaskID> subtasks,
long checkpointId,
- ChannelStateWriteResult result,
ChannelStateSerializer serializer,
RunnableWithException onComplete,
CheckpointStateOutputStream checkpointStateOutputStream,
DataOutputStream dataStream) {
- this.taskName = taskName;
- this.subtaskIndex = subtaskIndex;
+ checkArgument(!subtasks.isEmpty(), "The subtasks cannot be empty.");
+ this.subtasksToRegister = new HashSet<>(subtasks);
this.checkpointId = checkpointId;
- this.result = checkNotNull(result);
this.checkpointStream = checkNotNull(checkpointStateOutputStream);
this.serializer = checkNotNull(serializer);
this.dataStream = checkNotNull(dataStream);
@@ -134,164 +117,160 @@ class ChannelStateCheckpointWriter {
runWithChecks(() -> serializer.writeHeader(dataStream));
}
- void writeInput(InputChannelInfo info, Buffer buffer) {
- write(
- inputChannelOffsets,
- info,
- buffer,
- !allInputsReceived,
- "ChannelStateCheckpointWriter#writeInput");
+ void registerSubtaskResult(
+ SubtaskID subtaskID, ChannelStateWriter.ChannelStateWriteResult result) {
+ // The writer must not register any subtask after it has failed or is done.
+ checkState(!isDone(), "The write is done.");
+ Preconditions.checkState(
+ !pendingResults.containsKey(subtaskID),
+ "The subtask %s has already been register before.",
+ subtaskID);
+ subtasksToRegister.remove(subtaskID);
+
+ ChannelStatePendingResult pendingResult =
+ new ChannelStatePendingResult(
+ subtaskID.getSubtaskIndex(), checkpointId, result, serializer);
+ pendingResults.put(subtaskID, pendingResult);
}
- void writeOutput(ResultSubpartitionInfo info, Buffer buffer) {
- write(
- resultSubpartitionOffsets,
- info,
- buffer,
- !allOutputsReceived,
- "ChannelStateCheckpointWriter#writeOutput");
+ void releaseSubtask(SubtaskID subtaskID) throws Exception {
+ if (subtasksToRegister.remove(subtaskID)) {
+ // If all other subtasks of this writer have already completed their part of the
+ // checkpoint, the writer was only waiting for this subtask to register. Now that it is
+ // released, the writer can be completed.
+ tryFinishResult();
+ }
}
- private <K> void write(
- Map<K, StateContentMetaInfo> offsets,
- K key,
- Buffer buffer,
- boolean precondition,
- String action) {
+ void writeInput(
+ JobVertexID jobVertexID, int subtaskIndex, InputChannelInfo info, Buffer buffer) {
try {
- if (result.isDone()) {
+ if (isDone()) {
return;
}
- runWithChecks(
- () -> {
- checkState(precondition);
- long offset = checkpointStream.getPos();
- try (AutoCloseable ignored =
- NetworkActionsLogger.measureIO(action, buffer)) {
- serializer.writeData(dataStream, buffer);
- }
- long size = checkpointStream.getPos() - offset;
- offsets.computeIfAbsent(key, unused -> new StateContentMetaInfo())
- .withDataAdded(offset, size);
- NetworkActionsLogger.tracePersist(
- action, buffer, taskName, key, checkpointId);
- });
+ ChannelStatePendingResult pendingResult =
+ getChannelStatePendingResult(jobVertexID, subtaskIndex);
+ write(
+ pendingResult.getInputChannelOffsets(),
+ info,
+ buffer,
+ !pendingResult.isAllInputsReceived(),
+ "ChannelStateCheckpointWriter#writeInput");
} finally {
buffer.recycleBuffer();
}
}
- void completeInput() throws Exception {
- LOG.debug("complete input, output completed: {}", allOutputsReceived);
- complete(!allInputsReceived, () -> allInputsReceived = true);
+ void writeOutput(
+ JobVertexID jobVertexID, int subtaskIndex, ResultSubpartitionInfo info, Buffer buffer) {
+ try {
+ if (isDone()) {
+ return;
+ }
+ ChannelStatePendingResult pendingResult =
+ getChannelStatePendingResult(jobVertexID, subtaskIndex);
+ write(
+ pendingResult.getResultSubpartitionOffsets(),
+ info,
+ buffer,
+ !pendingResult.isAllOutputsReceived(),
+ "ChannelStateCheckpointWriter#writeOutput");
+ } finally {
+ buffer.recycleBuffer();
+ }
}
- void completeOutput() throws Exception {
- LOG.debug("complete output, input completed: {}", allInputsReceived);
- complete(!allOutputsReceived, () -> allOutputsReceived = true);
+ private <K> void write(
+ Map<K, StateContentMetaInfo> offsets,
+ K key,
+ Buffer buffer,
+ boolean precondition,
+ String action) {
+ runWithChecks(
+ () -> {
+ checkState(precondition);
+ long offset = checkpointStream.getPos();
+ try (AutoCloseable ignored = NetworkActionsLogger.measureIO(action, buffer)) {
+ serializer.writeData(dataStream, buffer);
+ }
+ long size = checkpointStream.getPos() - offset;
+ offsets.computeIfAbsent(key, unused -> new StateContentMetaInfo())
+ .withDataAdded(offset, size);
+ NetworkActionsLogger.tracePersist(action, buffer, key, checkpointId);
+ });
}
- private void complete(boolean precondition, RunnableWithException complete) throws Exception {
- if (result.isDone()) {
- // likely after abort - only need to set the flag run onComplete callback
- doComplete(precondition, complete, onComplete);
- } else {
- runWithChecks(
- () ->
- doComplete(
- precondition,
- complete,
- onComplete,
- this::finishWriteAndResult));
+ void completeInput(JobVertexID jobVertexID, int subtaskIndex) throws Exception {
+ if (isDone()) {
+ return;
}
+ getChannelStatePendingResult(jobVertexID, subtaskIndex).completeInput();
+ tryFinishResult();
}
- private void finishWriteAndResult() throws IOException {
- if (inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty()) {
- dataStream.close();
- result.inputChannelStateHandles.complete(emptyList());
- result.resultSubpartitionStateHandles.complete(emptyList());
+ void completeOutput(JobVertexID jobVertexID, int subtaskIndex) throws Exception {
+ if (isDone()) {
return;
}
- dataStream.flush();
- StreamStateHandle underlying = checkpointStream.closeAndGetHandle();
- complete(
- underlying,
- result.inputChannelStateHandles,
- inputChannelOffsets,
- HandleFactory.INPUT_CHANNEL);
- complete(
- underlying,
- result.resultSubpartitionStateHandles,
- resultSubpartitionOffsets,
- HandleFactory.RESULT_SUBPARTITION);
+ getChannelStatePendingResult(jobVertexID, subtaskIndex).completeOutput();
+ tryFinishResult();
}
- private void doComplete(
- boolean precondition,
- RunnableWithException complete,
- RunnableWithException... callbacks)
- throws Exception {
- Preconditions.checkArgument(precondition);
- complete.run();
- if (allInputsReceived && allOutputsReceived) {
- for (RunnableWithException callback : callbacks) {
- callback.run();
+ public void tryFinishResult() throws Exception {
+ if (!subtasksToRegister.isEmpty()) {
+ // Some subtasks are not registered yet
+ return;
+ }
+ for (ChannelStatePendingResult result : pendingResults.values()) {
+ if (result.isAllInputsReceived() && result.isAllOutputsReceived()) {
+ continue;
}
+ // Some subtasks did not receive all buffers
+ return;
}
- }
- private <I, H extends AbstractChannelStateHandle<I>> void complete(
- StreamStateHandle underlying,
- CompletableFuture<Collection<H>> future,
- Map<I, StateContentMetaInfo> offsets,
- HandleFactory<I, H> handleFactory)
- throws IOException {
- final Collection<H> handles = new ArrayList<>();
- for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
- handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
+ if (isDone()) {
+ // likely after abort - only need to run the onComplete callback
+ doComplete(onComplete);
+ } else {
+ runWithChecks(() -> doComplete(onComplete, this::finishWriteAndResult));
}
- future.complete(handles);
- LOG.debug(
- "channel state write completed, checkpointId: {}, handles: {}",
- checkpointId,
- handles);
}
- private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
- HandleFactory<I, H> handleFactory,
- StreamStateHandle underlying,
- I channelInfo,
- StateContentMetaInfo contentMetaInfo)
- throws IOException {
- Optional<byte[]> bytes =
- underlying.asBytesIfInMemory(); // todo: consider restructuring channel state and
- // removing this method:
- // https://issues.apache.org/jira/browse/FLINK-17972
- if (bytes.isPresent()) {
- StreamStateHandle extracted =
- new ByteStreamStateHandle(
- randomUUID().toString(),
- serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
- return handleFactory.create(
- subtaskIndex,
- channelInfo,
- extracted,
- singletonList(serializer.getHeaderLength()),
- extracted.getStateSize());
+ private void finishWriteAndResult() throws IOException {
+ StreamStateHandle stateHandle = null;
+ if (checkpointStream.getPos() == serializer.getHeaderLength()) {
+ dataStream.close();
} else {
- return handleFactory.create(
- subtaskIndex,
- channelInfo,
- underlying,
- contentMetaInfo.getOffsets(),
- contentMetaInfo.getSize());
+ dataStream.flush();
+ stateHandle = checkpointStream.closeAndGetHandle();
+ }
+ for (ChannelStatePendingResult result : pendingResults.values()) {
+ result.finishResult(stateHandle);
+ }
+ }
+
+ private void doComplete(RunnableWithException... callbacks) throws Exception {
+ for (RunnableWithException callback : callbacks) {
+ callback.run();
+ }
+ }
+
+ public boolean isDone() {
+ if (throwable != null) {
+ return true;
+ }
+ for (ChannelStatePendingResult result : pendingResults.values()) {
+ if (result.isDone()) {
+ return true;
+ }
}
+ return false;
}
private void runWithChecks(RunnableWithException r) {
try {
- checkState(!result.isDone(), "result is already completed", result);
+ checkState(!isDone(), "results are already completed", pendingResults.values());
r.run();
} catch (Exception e) {
fail(e);
@@ -301,8 +280,38 @@ class ChannelStateCheckpointWriter {
}
}
- public void fail(Throwable e) {
- result.fail(e);
+ /**
+ * The throwable is only used to fail the specific subtask that triggered the failure. Other
+ * subtasks are failed with {@code CHANNEL_STATE_SHARED_STREAM_EXCEPTION}.
+ */
+ public void fail(JobVertexID jobVertexID, int subtaskIndex, Throwable throwable) {
+ if (isDone()) {
+ return;
+ }
+ this.throwable = throwable;
+
+ ChannelStatePendingResult result =
+ pendingResults.get(SubtaskID.of(jobVertexID, subtaskIndex));
+ if (result != null) {
+ result.fail(throwable);
+ }
+ failResultAndCloseStream(
+ new CheckpointException(CHANNEL_STATE_SHARED_STREAM_EXCEPTION, throwable));
+ }
+
+ public void fail(Throwable throwable) {
+ if (isDone()) {
+ return;
+ }
+ this.throwable = throwable;
+
+ failResultAndCloseStream(throwable);
+ }
+
+ public void failResultAndCloseStream(Throwable e) {
+ for (ChannelStatePendingResult result : pendingResults.values()) {
+ result.fail(e);
+ }
try {
checkpointStream.close();
} catch (Exception closeException) {
@@ -315,18 +324,59 @@ class ChannelStateCheckpointWriter {
}
}
- private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
- H create(
- int subtaskIndex,
- I info,
- StreamStateHandle underlying,
- List<Long> offsets,
- long size);
+ @Nonnull
+ private ChannelStatePendingResult getChannelStatePendingResult(
+ JobVertexID jobVertexID, int subtaskIndex) {
+ SubtaskID subtaskID = SubtaskID.of(jobVertexID, subtaskIndex);
+ ChannelStatePendingResult pendingResult = pendingResults.get(subtaskID);
+ checkNotNull(pendingResult, "The subtask[%s] is not registered yet", subtaskID);
+ return pendingResult;
+ }
+}
+
+/** An identifier for a subtask. */
+class SubtaskID {
+
+ private final JobVertexID jobVertexID;
+ private final int subtaskIndex;
+
+ private SubtaskID(JobVertexID jobVertexID, int subtaskIndex) {
+ this.jobVertexID = jobVertexID;
+ this.subtaskIndex = subtaskIndex;
+ }
+
+ public JobVertexID getJobVertexID() {
+ return jobVertexID;
+ }
+
+ public int getSubtaskIndex() {
+ return subtaskIndex;
+ }
+
+ public static SubtaskID of(JobVertexID jobVertexID, int subtaskIndex) {
+ return new SubtaskID(jobVertexID, subtaskIndex);
+ }
- HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL =
- InputChannelStateHandle::new;
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ SubtaskID subtaskID = (SubtaskID) o;
+ return subtaskIndex == subtaskID.subtaskIndex
+ && Objects.equals(jobVertexID, subtaskID.jobVertexID);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(jobVertexID, subtaskIndex);
+ }
- HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION =
- ResultSubpartitionStateHandle::new;
+ @Override
+ public String toString() {
+ return "SubtaskID{" + "jobVertexID=" + jobVertexID + ", subtaskIndex=" + subtaskIndex + '}';
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java
new file mode 100644
index 00000000000..405ce0f392b
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStatePendingResult.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.runtime.state.AbstractChannelStateHandle;
+import org.apache.flink.runtime.state.AbstractChannelStateHandle.StateContentMetaInfo;
+import org.apache.flink.runtime.state.InputChannelStateHandle;
+import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
+import org.apache.flink.runtime.state.StreamStateHandle;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+
+import static java.util.Collections.singletonList;
+import static java.util.UUID.randomUUID;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** The pending result of channel state for a specific checkpoint-subtask. */
+public class ChannelStatePendingResult {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ChannelStatePendingResult.class);
+
+ // Subtask information
+ private final int subtaskIndex;
+
+ private final long checkpointId;
+
+ // Result related
+ private final ChannelStateSerializer serializer;
+ private final ChannelStateWriter.ChannelStateWriteResult result;
+ private final Map<InputChannelInfo, AbstractChannelStateHandle.StateContentMetaInfo>
+ inputChannelOffsets = new HashMap<>();
+ private final Map<ResultSubpartitionInfo, AbstractChannelStateHandle.StateContentMetaInfo>
+ resultSubpartitionOffsets = new HashMap<>();
+ private boolean allInputsReceived = false;
+ private boolean allOutputsReceived = false;
+
+ public ChannelStatePendingResult(
+ int subtaskIndex,
+ long checkpointId,
+ ChannelStateWriter.ChannelStateWriteResult result,
+ ChannelStateSerializer serializer) {
+ this.subtaskIndex = subtaskIndex;
+ this.checkpointId = checkpointId;
+ this.result = result;
+ this.serializer = serializer;
+ }
+
+ public boolean isAllInputsReceived() {
+ return allInputsReceived;
+ }
+
+ public boolean isAllOutputsReceived() {
+ return allOutputsReceived;
+ }
+
+ public Map<InputChannelInfo, StateContentMetaInfo> getInputChannelOffsets() {
+ return inputChannelOffsets;
+ }
+
+ public Map<ResultSubpartitionInfo, StateContentMetaInfo> getResultSubpartitionOffsets() {
+ return resultSubpartitionOffsets;
+ }
+
+ void completeInput() {
+ LOG.debug("complete input, output completed: {}", allOutputsReceived);
+ checkArgument(!allInputsReceived);
+ allInputsReceived = true;
+ }
+
+ void completeOutput() {
+ LOG.debug("complete output, input completed: {}", allInputsReceived);
+ checkArgument(!allOutputsReceived);
+ allOutputsReceived = true;
+ }
+
+ public void finishResult(@Nullable StreamStateHandle stateHandle) throws IOException {
+ checkState(
+ stateHandle != null
+ || (inputChannelOffsets.isEmpty() && resultSubpartitionOffsets.isEmpty()),
+ "The stateHandle just can be null when no data is written.");
+ complete(
+ stateHandle,
+ result.inputChannelStateHandles,
+ inputChannelOffsets,
+ HandleFactory.INPUT_CHANNEL);
+ complete(
+ stateHandle,
+ result.resultSubpartitionStateHandles,
+ resultSubpartitionOffsets,
+ HandleFactory.RESULT_SUBPARTITION);
+ }
+
+ private <I, H extends AbstractChannelStateHandle<I>> void complete(
+ StreamStateHandle underlying,
+ CompletableFuture<Collection<H>> future,
+ Map<I, StateContentMetaInfo> offsets,
+ HandleFactory<I, H> handleFactory)
+ throws IOException {
+ final Collection<H> handles = new ArrayList<>();
+ for (Map.Entry<I, StateContentMetaInfo> e : offsets.entrySet()) {
+ handles.add(createHandle(handleFactory, underlying, e.getKey(), e.getValue()));
+ }
+ future.complete(handles);
+ LOG.debug(
+ "channel state write completed, checkpointId: {}, handles: {}",
+ checkpointId,
+ handles);
+ }
+
+ private <I, H extends AbstractChannelStateHandle<I>> H createHandle(
+ HandleFactory<I, H> handleFactory,
+ StreamStateHandle underlying,
+ I channelInfo,
+ StateContentMetaInfo contentMetaInfo)
+ throws IOException {
+ Optional<byte[]> bytes =
+ underlying.asBytesIfInMemory(); // todo: consider restructuring channel state and
+ // removing this method:
+ // https://issues.apache.org/jira/browse/FLINK-17972
+ if (bytes.isPresent()) {
+ StreamStateHandle extracted =
+ new ByteStreamStateHandle(
+ randomUUID().toString(),
+ serializer.extractAndMerge(bytes.get(), contentMetaInfo.getOffsets()));
+ return handleFactory.create(
+ subtaskIndex,
+ channelInfo,
+ extracted,
+ singletonList(serializer.getHeaderLength()),
+ extracted.getStateSize());
+ } else {
+ return handleFactory.create(
+ subtaskIndex,
+ channelInfo,
+ underlying,
+ contentMetaInfo.getOffsets(),
+ contentMetaInfo.getSize());
+ }
+ }
+
+ public void fail(Throwable e) {
+ result.fail(e);
+ }
+
+ public boolean isDone() {
+ return this.result.isDone();
+ }
+
+ private interface HandleFactory<I, H extends AbstractChannelStateHandle<I>> {
+ H create(
+ int subtaskIndex,
+ I info,
+ StreamStateHandle underlying,
+ List<Long> offsets,
+ long size);
+
+ HandleFactory<InputChannelInfo, InputChannelStateHandle> INPUT_CHANNEL =
+ InputChannelStateHandle::new;
+
+ HandleFactory<ResultSubpartitionInfo, ResultSubpartitionStateHandle> RESULT_SUBPARTITION =
+ ResultSubpartitionStateHandle::new;
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
index 3fcd8975c61..abef241c325 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
@@ -18,7 +18,9 @@
package org.apache.flink.runtime.checkpoint.channel;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
+import org.apache.flink.runtime.io.AvailabilityProvider;
import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
@@ -27,6 +29,8 @@ import org.apache.flink.util.function.ThrowingConsumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.annotation.Nullable;
+
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
@@ -41,69 +45,151 @@ import static org.apache.flink.runtime.checkpoint.channel.CheckpointInProgressRe
import static org.apache.flink.util.CloseableIterator.ofElements;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+abstract class ChannelStateWriteRequest {
+
+ private static final Logger LOG = LoggerFactory.getLogger(ChannelStateWriteRequest.class);
+
+ private final JobVertexID jobVertexID;
+
+ private final int subtaskIndex;
-interface ChannelStateWriteRequest {
+ private final long checkpointId;
+
+ private final String name;
- Logger LOG = LoggerFactory.getLogger(ChannelStateWriteRequest.class);
+ public ChannelStateWriteRequest( // stores the given values as-is, without validation
+ JobVertexID jobVertexID, int subtaskIndex, long checkpointId, String name) {
+ this.jobVertexID = jobVertexID;
+ this.subtaskIndex = subtaskIndex;
+ this.checkpointId = checkpointId;
+ this.name = name; // label printed first by toString()
+ }
+
+ public final JobVertexID getJobVertexID() { // the jobVertexID given to the constructor
+ return jobVertexID;
+ }
+
+ public final int getSubtaskIndex() { // the subtaskIndex given to the constructor
+ return subtaskIndex;
+ }
+
+ public final long getCheckpointId() { // the checkpointId given to the constructor
+ return checkpointId;
+ }
+
+ /**
+ * Returns a future that tells whether the request is ready to be executed, e.g. a request that
+ * writes a channel state data future is not ready until that data future is done.
+ *
+ * <p>The ready future is used by {@link ChannelStateWriteRequestExecutorImpl}: the executor
+ * processes ready requests first to avoid deadlock. By default it is AvailabilityProvider.AVAILABLE.
+ */
+ public CompletableFuture<?> getReadyFuture() {
+ return AvailabilityProvider.AVAILABLE;
+ }
- long getCheckpointId();
+ @Override
+ public String toString() { // format: name {jobVertexID=..., subtaskIndex=..., checkpointId=...}
+ return name
+ + " {jobVertexID="
+ + jobVertexID
+ + ", subtaskIndex="
+ + subtaskIndex
+ + ", checkpointId="
+ + checkpointId
+ + '}';
+ }
- void cancel(Throwable cause) throws Exception;
+ abstract void cancel(Throwable cause) throws Exception;
- static CheckpointInProgressRequest completeInput(long checkpointId) {
+ static CheckpointInProgressRequest completeInput(
+ JobVertexID jobVertexID, int subtaskIndex, long checkpointId) {
return new CheckpointInProgressRequest(
- "completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput);
+ "completeInput",
+ jobVertexID,
+ subtaskIndex,
+ checkpointId,
+ writer -> writer.completeInput(jobVertexID, subtaskIndex));
}
- static CheckpointInProgressRequest completeOutput(long checkpointId) {
+ static CheckpointInProgressRequest completeOutput(
+ JobVertexID jobVertexID, int subtaskIndex, long checkpointId) {
return new CheckpointInProgressRequest(
- "completeOutput", checkpointId, ChannelStateCheckpointWriter::completeOutput);
+ "completeOutput",
+ jobVertexID,
+ subtaskIndex,
+ checkpointId,
+ writer -> writer.completeOutput(jobVertexID, subtaskIndex));
}
static ChannelStateWriteRequest write(
- long checkpointId, InputChannelInfo info, CloseableIterator<Buffer> iterator) {
+ JobVertexID jobVertexID,
+ int subtaskIndex,
+ long checkpointId,
+ InputChannelInfo info,
+ CloseableIterator<Buffer> iterator) {
return buildWriteRequest(
+ jobVertexID,
+ subtaskIndex,
checkpointId,
"writeInput",
iterator,
- (writer, buffer) -> writer.writeInput(info, buffer));
+ (writer, buffer) -> writer.writeInput(jobVertexID, subtaskIndex, info, buffer));
}
static ChannelStateWriteRequest write(
- long checkpointId, ResultSubpartitionInfo info, Buffer... buffers) {
+ JobVertexID jobVertexID,
+ int subtaskIndex,
+ long checkpointId,
+ ResultSubpartitionInfo info,
+ Buffer... buffers) {
return buildWriteRequest(
+ jobVertexID,
+ subtaskIndex,
checkpointId,
"writeOutput",
ofElements(Buffer::recycleBuffer, buffers),
- (writer, buffer) -> writer.writeOutput(info, buffer));
+ (writer, buffer) -> writer.writeOutput(jobVertexID, subtaskIndex, info, buffer));
}
static ChannelStateWriteRequest write(
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
ResultSubpartitionInfo info,
CompletableFuture<List<Buffer>> dataFuture) {
return buildFutureWriteRequest(
+ jobVertexID,
+ subtaskIndex,
checkpointId,
"writeOutputFuture",
dataFuture,
- (writer, buffer) -> writer.writeOutput(info, buffer));
+ (writer, buffer) -> writer.writeOutput(jobVertexID, subtaskIndex, info, buffer));
}
static ChannelStateWriteRequest buildFutureWriteRequest(
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
String name,
CompletableFuture<List<Buffer>> dataFuture,
BiConsumer<ChannelStateCheckpointWriter, Buffer> bufferConsumer) {
return new CheckpointInProgressRequest(
name,
+ jobVertexID,
+ subtaskIndex,
checkpointId,
writer -> {
+ checkState(
+ dataFuture.isDone(), "It should be executed when dataFuture is done.");
List<Buffer> buffers;
try {
buffers = dataFuture.get();
} catch (ExecutionException e) {
// If dataFuture fails, fail only the single related writer
- writer.fail(e);
+ writer.fail(jobVertexID, subtaskIndex, e);
return;
}
for (Buffer buffer : buffers) {
@@ -122,16 +208,21 @@ interface ChannelStateWriteRequest {
"Failed to recycle the output buffer of channel state.",
e);
}
- }));
+ }),
+ dataFuture);
}
static ChannelStateWriteRequest buildWriteRequest(
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
String name,
CloseableIterator<Buffer> iterator,
BiConsumer<ChannelStateCheckpointWriter, Buffer> bufferConsumer) {
return new CheckpointInProgressRequest(
name,
+ jobVertexID,
+ subtaskIndex,
checkpointId,
writer -> {
while (iterator.hasNext()) {
@@ -153,14 +244,26 @@ interface ChannelStateWriteRequest {
}
static ChannelStateWriteRequest start(
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
ChannelStateWriteResult targetResult,
CheckpointStorageLocationReference locationReference) {
- return new CheckpointStartRequest(checkpointId, targetResult, locationReference);
+ return new CheckpointStartRequest(
+ jobVertexID, subtaskIndex, checkpointId, targetResult, locationReference);
+ }
+
+ static ChannelStateWriteRequest abort( // creates a CheckpointAbortRequest carrying the abort cause
+ JobVertexID jobVertexID, int subtaskIndex, long checkpointId, Throwable cause) {
+ return new CheckpointAbortRequest(jobVertexID, subtaskIndex, checkpointId, cause);
}
- static ChannelStateWriteRequest abort(long checkpointId, Throwable cause) {
- return new CheckpointAbortRequest(checkpointId, cause);
+ static ChannelStateWriteRequest registerSubtask(JobVertexID jobVertexID, int subtaskIndex) { // no checkpoint id needed
+ return new SubtaskRegisterRequest(jobVertexID, subtaskIndex);
+ }
+
+ static ChannelStateWriteRequest releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) { // no checkpoint id needed
+ return new SubtaskReleaseRequest(jobVertexID, subtaskIndex);
}
static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) {
@@ -172,25 +275,22 @@ interface ChannelStateWriteRequest {
}
}
-final class CheckpointStartRequest implements ChannelStateWriteRequest {
+final class CheckpointStartRequest extends ChannelStateWriteRequest {
+
private final ChannelStateWriteResult targetResult;
private final CheckpointStorageLocationReference locationReference;
- private final long checkpointId;
CheckpointStartRequest(
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
ChannelStateWriteResult targetResult,
CheckpointStorageLocationReference locationReference) {
- this.checkpointId = checkpointId;
+ super(jobVertexID, subtaskIndex, checkpointId, "Start");
this.targetResult = checkNotNull(targetResult);
this.locationReference = checkNotNull(locationReference);
}
- @Override
- public long getCheckpointId() {
- return checkpointId;
- }
-
ChannelStateWriteResult getTargetResult() {
return targetResult;
}
@@ -203,11 +303,6 @@ final class CheckpointStartRequest implements ChannelStateWriteRequest {
public void cancel(Throwable cause) {
targetResult.fail(cause);
}
-
- @Override
- public String toString() {
- return "start " + checkpointId;
- }
}
enum CheckpointInProgressRequestState {
@@ -218,35 +313,44 @@ enum CheckpointInProgressRequestState {
CANCELLED
}
-final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
+final class CheckpointInProgressRequest extends ChannelStateWriteRequest {
private final ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action;
private final ThrowingConsumer<Throwable, Exception> discardAction;
- private final long checkpointId;
- private final String name;
private final AtomicReference<CheckpointInProgressRequestState> state =
new AtomicReference<>(NEW);
+ @Nullable private final CompletableFuture<?> readyFuture;
CheckpointInProgressRequest(
String name,
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action) {
- this(name, checkpointId, action, unused -> {});
+ this(name, jobVertexID, subtaskIndex, checkpointId, action, unused -> {});
}
CheckpointInProgressRequest(
String name,
+ JobVertexID jobVertexID,
+ int subtaskIndex,
long checkpointId,
ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
ThrowingConsumer<Throwable, Exception> discardAction) {
- this.checkpointId = checkpointId;
- this.action = checkNotNull(action);
- this.discardAction = checkNotNull(discardAction);
- this.name = checkNotNull(name);
+ this(name, jobVertexID, subtaskIndex, checkpointId, action, discardAction, null);
}
- @Override
- public long getCheckpointId() {
- return checkpointId;
+ CheckpointInProgressRequest( // full constructor; the shorter overloads pass readyFuture = null
+ String name,
+ JobVertexID jobVertexID,
+ int subtaskIndex,
+ long checkpointId,
+ ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
+ ThrowingConsumer<Throwable, Exception> discardAction,
+ @Nullable CompletableFuture<?> readyFuture) { // null: getReadyFuture() uses the base class future
+ super(jobVertexID, subtaskIndex, checkpointId, name);
+ this.action = checkNotNull(action);
+ this.discardAction = checkNotNull(discardAction);
+ this.readyFuture = readyFuture; // nullable, hence no checkNotNull
}
@Override
@@ -268,19 +372,21 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
}
@Override
- public String toString() {
- return name + " " + checkpointId;
+ public CompletableFuture<?> getReadyFuture() {
+ if (readyFuture != null) {
+ return readyFuture;
+ }
+ return super.getReadyFuture();
}
}
-final class CheckpointAbortRequest implements ChannelStateWriteRequest {
-
- private final long checkpointId;
+final class CheckpointAbortRequest extends ChannelStateWriteRequest {
private final Throwable throwable;
- public CheckpointAbortRequest(long checkpointId, Throwable throwable) {
- this.checkpointId = checkpointId;
+ public CheckpointAbortRequest(
+ JobVertexID jobVertexID, int subtaskIndex, long checkpointId, Throwable throwable) {
+ super(jobVertexID, subtaskIndex, checkpointId, "Abort");
this.throwable = throwable;
}
@@ -289,15 +395,30 @@ final class CheckpointAbortRequest implements ChannelStateWriteRequest {
}
@Override
- public long getCheckpointId() {
- return checkpointId;
+ public void cancel(Throwable cause) throws Exception {}
+
+ @Override
+ public String toString() { // common request description followed by the abort cause
+ return String.format("%s, cause : %s.", super.toString(), throwable);
+ }
+}
+
+final class SubtaskRegisterRequest extends ChannelStateWriteRequest { // adds a subtask to the dispatcher's registeredSubtasks
+
+ public SubtaskRegisterRequest(JobVertexID jobVertexID, int subtaskIndex) {
+ super(jobVertexID, subtaskIndex, 0, "Register"); // checkpointId fixed to 0; dispatched before any checkpoint id check
}
@Override
public void cancel(Throwable cause) throws Exception {}
+}
- @Override
- public String toString() {
- return "Abort checkpointId-" + checkpointId;
+final class SubtaskReleaseRequest extends ChannelStateWriteRequest { // removes a subtask from registeredSubtasks and the current writer
+
+ public SubtaskReleaseRequest(JobVertexID jobVertexID, int subtaskIndex) {
+ super(jobVertexID, subtaskIndex, 0, "Release"); // checkpointId fixed to 0; dispatched before any checkpoint id check
}
+
+ @Override
+ public void cancel(Throwable cause) throws Exception {} // empty: cancelling does nothing
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
index fb2e4bedc79..5151d9be701 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
@@ -17,12 +17,20 @@
package org.apache.flink.runtime.checkpoint.channel;
+import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.CheckpointStorage;
import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import java.io.IOException;
+import java.util.HashSet;
+import java.util.Set;
+
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -36,10 +44,15 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
private static final Logger LOG =
LoggerFactory.getLogger(ChannelStateWriteRequestDispatcherImpl.class);
- private final CheckpointStorageWorkerView streamFactoryResolver;
+ private final CheckpointStorage checkpointStorage;
+
+ private final JobID jobID;
+
private final ChannelStateSerializer serializer;
- private final int subtaskIndex;
- private final String taskName;
+
+ private final Set<SubtaskID> registeredSubtasks;
+
+ private CheckpointStorageWorkerView streamFactoryResolver;
/**
* It is the checkpointId corresponding to writer. And It should be always update with {@link
@@ -53,6 +66,9 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
*/
private long maxAbortedCheckpointId;
+ /** The subtask that aborted the checkpoint maxAbortedCheckpointId. */
+ private SubtaskID abortedSubtaskID;
+
/** The aborted cause of the maxAbortedCheckpointId. */
private Throwable abortedCause;
@@ -62,14 +78,11 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
private ChannelStateCheckpointWriter writer;
ChannelStateWriteRequestDispatcherImpl(
- String taskName,
- int subtaskIndex,
- CheckpointStorageWorkerView streamFactoryResolver,
- ChannelStateSerializer serializer) {
- this.taskName = taskName;
- this.subtaskIndex = subtaskIndex;
- this.streamFactoryResolver = checkNotNull(streamFactoryResolver);
+ CheckpointStorage checkpointStorage, JobID jobID, ChannelStateSerializer serializer) {
+ this.checkpointStorage = checkNotNull(checkpointStorage);
+ this.jobID = jobID;
this.serializer = checkNotNull(serializer);
+ this.registeredSubtasks = new HashSet<>();
this.ongoingCheckpointId = -1;
this.maxAbortedCheckpointId = -1;
}
@@ -90,56 +103,95 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
}
private void dispatchInternal(ChannelStateWriteRequest request) throws Exception {
- if (isAbortedCheckpoint(request.getCheckpointId())) {
- if (request.getCheckpointId() == maxAbortedCheckpointId) {
- request.cancel(abortedCause);
- } else {
- request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+ if (request instanceof SubtaskRegisterRequest) {
+ SubtaskRegisterRequest req = (SubtaskRegisterRequest) request;
+ SubtaskID subtaskID = SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex());
+ registeredSubtasks.add(subtaskID);
+ return;
+ } else if (request instanceof SubtaskReleaseRequest) {
+ SubtaskReleaseRequest req = (SubtaskReleaseRequest) request;
+ SubtaskID subtaskID = SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex());
+ registeredSubtasks.remove(subtaskID);
+ if (writer == null) {
+ return;
}
+ writer.releaseSubtask(subtaskID);
return;
}
- if (request instanceof CheckpointStartRequest) {
- checkState(
- request.getCheckpointId() > ongoingCheckpointId,
- String.format(
- "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.",
- ongoingCheckpointId, request));
- failAndClearWriter(
- new IllegalStateException(
- String.format(
- "Task[name=%s, subtaskIndex=%s] has uncompleted channelState writer of checkpointId=%s, "
- + "but it received a new checkpoint start request of checkpointId=%s, it maybe "
- + "a bug due to currently not supported concurrent unaligned checkpoint.",
- taskName,
- subtaskIndex,
- ongoingCheckpointId,
- request.getCheckpointId())));
- this.writer = buildWriter((CheckpointStartRequest) request);
- this.ongoingCheckpointId = request.getCheckpointId();
+ if (isAbortedCheckpoint(request.getCheckpointId())) {
+ handleAbortedRequest(request);
+ } else if (request instanceof CheckpointStartRequest) {
+ handleCheckpointStartRequest(request);
} else if (request instanceof CheckpointInProgressRequest) {
- CheckpointInProgressRequest req = (CheckpointInProgressRequest) request;
- checkArgument(
- ongoingCheckpointId == req.getCheckpointId() && writer != null,
- "writer not found while processing request: " + req);
- req.execute(writer);
+ handleCheckpointInProgressRequest((CheckpointInProgressRequest) request);
} else if (request instanceof CheckpointAbortRequest) {
- CheckpointAbortRequest req = (CheckpointAbortRequest) request;
- if (request.getCheckpointId() > maxAbortedCheckpointId) {
- this.maxAbortedCheckpointId = req.getCheckpointId();
- this.abortedCause = req.getThrowable();
- }
-
- if (req.getCheckpointId() == ongoingCheckpointId) {
- failAndClearWriter(req.getThrowable());
- } else if (request.getCheckpointId() > ongoingCheckpointId) {
- failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
- }
+ handleCheckpointAbortRequest(request);
} else {
throw new IllegalArgumentException("unknown request type: " + request);
}
}
+ private void handleAbortedRequest(ChannelStateWriteRequest request) throws Exception { // called when isAbortedCheckpoint is true
+ if (request.getCheckpointId() != maxAbortedCheckpointId) { // not the latest aborted checkpoint: cancel as subsumed
+ request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+ return;
+ }
+
+ SubtaskID requestSubtask =
+ SubtaskID.of(request.getJobVertexID(), request.getSubtaskIndex());
+ if (requestSubtask.equals(abortedSubtaskID)) { // the subtask that sent the abort gets the original cause
+ request.cancel(abortedCause);
+ } else { // any other subtask gets the cause wrapped in a shared-stream CheckpointException
+ request.cancel(
+ new CheckpointException(CHANNEL_STATE_SHARED_STREAM_EXCEPTION, abortedCause));
+ }
+ }
+
+ private void handleCheckpointStartRequest(ChannelStateWriteRequest request) throws Exception {
+ checkState(
+ request.getCheckpointId() >= ongoingCheckpointId, // an id equal to the ongoing one is accepted
+ String.format(
+ "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.",
+ ongoingCheckpointId, request));
+ if (request.getCheckpointId() > ongoingCheckpointId) {
+ // A newer checkpoint starts: fail the previous writer (if any) as subsumed and clear it.
+ failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+ }
+ CheckpointStartRequest req = (CheckpointStartRequest) request;
+ // The writer may already exist, because another subtask may have built it for the
+ // ongoingCheckpointId when multiple subtasks share a channel state file.
+ if (writer == null) {
+ this.writer = buildWriter(req);
+ this.ongoingCheckpointId = request.getCheckpointId();
+ }
+ writer.registerSubtaskResult( // each start request registers its own target result with the writer
+ SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex()), req.getTargetResult());
+ }
+
+ private void handleCheckpointInProgressRequest(CheckpointInProgressRequest req)
+ throws Exception {
+ checkArgument(
+ ongoingCheckpointId == req.getCheckpointId() && writer != null, // needs a writer for exactly this checkpoint
+ "writer not found while processing request: " + req);
+ req.execute(writer); // executes the request against the current writer
+ }
+
+ private void handleCheckpointAbortRequest(ChannelStateWriteRequest request) {
+ CheckpointAbortRequest req = (CheckpointAbortRequest) request;
+ if (request.getCheckpointId() > maxAbortedCheckpointId) { // remember the newest abort: id, cause and subtask
+ this.maxAbortedCheckpointId = req.getCheckpointId();
+ this.abortedCause = req.getThrowable();
+ this.abortedSubtaskID = SubtaskID.of(req.getJobVertexID(), req.getSubtaskIndex());
+ }
+
+ if (req.getCheckpointId() == ongoingCheckpointId) { // abort of the ongoing checkpoint: fail with the given cause
+ failAndClearWriter(req.getJobVertexID(), req.getSubtaskIndex(), req.getThrowable());
+ } else if (request.getCheckpointId() > ongoingCheckpointId) { // abort of a newer checkpoint: fail as subsumed
+ failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+ }
+ }
+
private boolean isAbortedCheckpoint(long checkpointId) {
return checkpointId < ongoingCheckpointId || checkpointId <= maxAbortedCheckpointId;
}
@@ -152,14 +204,23 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
writer = null;
}
+ private void failAndClearWriter( // calls writer.fail for the given subtask, then drops the writer
+ JobVertexID jobVertexID, int subtaskIndex, Throwable throwable) {
+ if (writer == null) { // nothing to fail
+ return;
+ }
+ writer.fail(jobVertexID, subtaskIndex, throwable);
+ writer = null;
+ }
+
private ChannelStateCheckpointWriter buildWriter(CheckpointStartRequest request)
throws Exception {
return new ChannelStateCheckpointWriter(
- taskName,
- subtaskIndex,
- request,
- streamFactoryResolver.resolveCheckpointStorageLocation(
- request.getCheckpointId(), request.getLocationReference()),
+ registeredSubtasks,
+ request.getCheckpointId(),
+ getStreamFactoryResolver()
+ .resolveCheckpointStorageLocation(
+ request.getCheckpointId(), request.getLocationReference()),
serializer,
() -> {
checkState(
@@ -183,4 +244,11 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
}
writer = null;
}
+
+ CheckpointStorageWorkerView getStreamFactoryResolver() throws IOException { // created lazily and cached
+ if (streamFactoryResolver == null) { // first use: create it from the checkpoint storage for this job
+ streamFactoryResolver = checkpointStorage.createCheckpointStorage(jobID);
+ }
+ return streamFactoryResolver;
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java
index 52256873f93..42d0142f112 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutor.java
@@ -17,15 +17,17 @@
package org.apache.flink.runtime.checkpoint.channel;
-import java.io.Closeable;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
+import java.io.IOException;
/**
* Executes {@link ChannelStateWriteRequest}s potentially asynchronously. An exception thrown during
* the execution should be re-thrown on any next call.
*/
-interface ChannelStateWriteRequestExecutor extends Closeable {
+interface ChannelStateWriteRequestExecutor {
- /** @throws IllegalStateException if called more than once or after {@link #close()} */
+ /** @throws IllegalStateException if called more than once or after {@link #releaseSubtask} */
void start() throws IllegalStateException;
/**
@@ -45,4 +47,10 @@ interface ChannelStateWriteRequestExecutor extends Closeable {
* @throws Exception if any exception occurred during processing this or other items previously
*/
void submitPriority(ChannelStateWriteRequest r) throws Exception;
+
+ /** Registers the subtask identified by the given job vertex and subtask index. */
+ void registerSubtask(JobVertexID jobVertexID, int subtaskIndex);
+
+ /** Releases the subtask identified by the given job vertex and subtask index. */
+ void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) throws IOException;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java
new file mode 100644
index 00000000000..c41e96a24f5
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactory.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.CheckpointStorage;
+
+import javax.annotation.concurrent.GuardedBy;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** The factory of {@link ChannelStateWriteRequestExecutor}; it hands out one shared executor at a time. */
+public class ChannelStateWriteRequestExecutorFactory {
+
+ private final JobID jobID;
+
+ private final Object lock = new Object();
+
+ @GuardedBy("lock")
+ private ChannelStateWriteRequestExecutor executor; // current shared executor, or null if none is available
+
+ public ChannelStateWriteRequestExecutorFactory(JobID jobID) {
+ this.jobID = jobID;
+ }
+
+ public ChannelStateWriteRequestExecutor getOrCreateExecutor( // returns an executor with the subtask registered
+ JobVertexID jobVertexID,
+ int subtaskIndex,
+ CheckpointStorage checkpointStorage,
+ int maxSubtasksPerChannelStateFile) {
+ synchronized (lock) {
+ if (executor == null) { // no shared executor at the moment: create and start a new one
+ executor =
+ new ChannelStateWriteRequestExecutorImpl(
+ new ChannelStateWriteRequestDispatcherImpl(
+ checkpointStorage, jobID, new ChannelStateSerializerImpl()),
+ maxSubtasksPerChannelStateFile,
+ executor -> { // onRegistered callback: forget the executor so the next call creates a new one
+ synchronized (lock) {
+ checkState(this.executor == executor);
+ this.executor = null;
+ }
+ });
+ executor.start();
+ }
+ ChannelStateWriteRequestExecutor currentExecutor = executor;
+ currentExecutor.registerSubtask(jobVertexID, subtaskIndex); // done while still holding the factory lock
+ return currentExecutor;
+ }
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
index 5c12a778d64..bafc5b69c8c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImpl.java
@@ -18,23 +18,37 @@
package org.apache.flink.runtime.checkpoint.channel;
import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.core.fs.FileSystemSafetyNet;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.function.RunnableWithException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import java.io.IOException;
+import java.util.ArrayDeque;
import java.util.ArrayList;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
-import java.util.concurrent.BlockingDeque;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Queue;
+import java.util.Set;
import java.util.concurrent.CancellationException;
-import java.util.concurrent.LinkedBlockingDeque;
+import java.util.function.Consumer;
import java.util.stream.Collectors;
import static org.apache.flink.util.IOUtils.closeAll;
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
/**
* Executes {@link ChannelStateWriteRequest}s in a separate thread. Any exception occurred during
@@ -46,32 +60,59 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
private static final Logger LOG =
LoggerFactory.getLogger(ChannelStateWriteRequestExecutorImpl.class);
+ private final Object lock = new Object();
+
private final ChannelStateWriteRequestDispatcher dispatcher;
- private final BlockingDeque<ChannelStateWriteRequest> deque;
private final Thread thread;
- private volatile Exception thrown = null;
- private volatile boolean wasClosed = false;
- private final String taskName;
+
+ private final int maxSubtasksPerChannelStateFile;
+
+ @GuardedBy("lock")
+ private final Deque<ChannelStateWriteRequest> deque;
+
+ @GuardedBy("lock")
+ private Exception thrown = null;
+
+ @GuardedBy("lock")
+ private boolean wasClosed = false;
+
+ @GuardedBy("lock")
+ private final Map<SubtaskID, Queue<ChannelStateWriteRequest>> unreadyQueues = new HashMap<>();
+
+ @GuardedBy("lock")
+ private boolean isRegistering = true;
+
+ @GuardedBy("lock")
+ private final Set<SubtaskID> subtasks;
+
+ /** It must not be called while holding the {@link #lock}, to avoid a deadlock. */
+ private final Consumer<ChannelStateWriteRequestExecutor> onRegistered;
ChannelStateWriteRequestExecutorImpl(
- String taskName, ChannelStateWriteRequestDispatcher dispatcher) {
- this(taskName, dispatcher, new LinkedBlockingDeque<>());
+ ChannelStateWriteRequestDispatcher dispatcher,
+ int maxSubtasksPerChannelStateFile,
+ Consumer<ChannelStateWriteRequestExecutor> onRegistered) {
+ this(dispatcher, new ArrayDeque<>(), maxSubtasksPerChannelStateFile, onRegistered);
}
ChannelStateWriteRequestExecutorImpl(
- String taskName,
ChannelStateWriteRequestDispatcher dispatcher,
- BlockingDeque<ChannelStateWriteRequest> deque) {
- this.taskName = taskName;
+ Deque<ChannelStateWriteRequest> deque,
+ int maxSubtasksPerChannelStateFile,
+ Consumer<ChannelStateWriteRequestExecutor> onRegistered) {
this.dispatcher = dispatcher;
this.deque = deque;
- this.thread = new Thread(this::run, "Channel state writer " + taskName);
+ this.maxSubtasksPerChannelStateFile = maxSubtasksPerChannelStateFile;
+ this.onRegistered = onRegistered;
+ this.thread = new Thread(this::run, "Channel state writer ");
+ this.subtasks = new HashSet<>();
this.thread.setDaemon(true);
}
@VisibleForTesting
void run() {
try {
+ FileSystemSafetyNet.initializeSafetyNetForThread();
loop();
} catch (Exception ex) {
thrown = ex;
@@ -79,39 +120,94 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
try {
closeAll(
this::cleanupRequests,
- () ->
- dispatcher.fail(
- thrown == null ? new CancellationException() : thrown));
+ () -> {
+ Throwable cause;
+ synchronized (lock) {
+ cause = thrown == null ? new CancellationException() : thrown;
+ }
+ dispatcher.fail(cause);
+ });
} catch (Exception e) {
- //noinspection NonAtomicOperationOnVolatileField
- thrown = ExceptionUtils.firstOrSuppressed(e, thrown);
+ synchronized (lock) {
+ //noinspection NonAtomicOperationOnVolatileField
+ thrown = ExceptionUtils.firstOrSuppressed(e, thrown);
+ }
}
+ FileSystemSafetyNet.closeSafetyNetAndGuardedResourcesForThread();
}
- LOG.debug("{} loop terminated", taskName);
+ LOG.debug("loop terminated");
}
private void loop() throws Exception {
- while (!wasClosed) {
+ while (true) {
try {
- dispatcher.dispatch(deque.take());
+ ChannelStateWriteRequest request;
+ boolean completeRegister = false;
+ synchronized (lock) {
+ request = waitAndTakeUnsafe();
+ if (request == null) {
+ // The executor is closed, so return directly.
+ return;
+ }
+ // The executor ends the registration when the start request arrives, because
+ // the checkpoint can be started only after all tasks are initiated.
+ if (request instanceof CheckpointStartRequest) {
+ completeRegister = completeRegister();
+ }
+ }
+ if (completeRegister) {
+ onRegistered.accept(this);
+ }
+ dispatcher.dispatch(request);
} catch (InterruptedException e) {
- if (!wasClosed) {
- LOG.debug(
- taskName
- + " interrupted while waiting for a request (continue waiting)",
- e);
- } else {
- Thread.currentThread().interrupt();
+ synchronized (lock) {
+ if (!wasClosed) {
+ LOG.debug(
+ "Channel state executor is interrupted while waiting for a request (continue waiting)",
+ e);
+ } else {
+ Thread.currentThread().interrupt();
+ return;
+ }
}
}
}
}
+ /**
+ * Retrieves and removes the head request of the {@link #deque}, waiting if necessary until an
+ * element becomes available.
+ *
+ * @return The head request, or null if the executor is closed.
+ */
+ @Nullable
+ private ChannelStateWriteRequest waitAndTakeUnsafe() throws InterruptedException {
+ ChannelStateWriteRequest request;
+ while (!wasClosed) {
+ request = deque.pollFirst();
+ if (request == null) {
+ lock.wait();
+ } else {
+ return request;
+ }
+ }
+ return null;
+ }
+
private void cleanupRequests() throws Exception {
- Throwable cause = thrown == null ? new CancellationException() : thrown;
- List<ChannelStateWriteRequest> drained = new ArrayList<>();
- deque.drainTo(drained);
- LOG.info("{} discarding {} drained requests", taskName, drained.size());
+ List<ChannelStateWriteRequest> drained;
+ Throwable cause;
+ synchronized (lock) {
+ cause = thrown == null ? new CancellationException() : thrown;
+ drained = new ArrayList<>(deque);
+ deque.clear();
+ for (Queue<ChannelStateWriteRequest> unreadyQueue : unreadyQueues.values()) {
+ while (!unreadyQueue.isEmpty()) {
+ drained.add(unreadyQueue.poll());
+ }
+ }
+ }
+ LOG.info("discarding {} drained requests", drained.size());
closeAll(
drained.stream()
.<AutoCloseable>map(request -> () -> request.cancel(cause))
@@ -125,12 +221,91 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
@Override
public void submit(ChannelStateWriteRequest request) throws Exception {
- submitInternal(request, () -> deque.add(request));
+ synchronized (lock) {
+ // Resolve the subtask ID once so that it can also be reported in the failure message.
+ SubtaskID subtaskID = SubtaskID.of(request.getJobVertexID(), request.getSubtaskIndex());
+ Queue<ChannelStateWriteRequest> unreadyQueue = unreadyQueues.get(subtaskID);
+ checkArgument(unreadyQueue != null, "The subtask %s is not yet registered.", subtaskID);
+ submitInternal(
+ request,
+ () -> {
+ // 1. unreadyQueue isn't empty, the new request must keep the order, so add
+ // the new request to the unreadyQueue tail.
+ if (!unreadyQueue.isEmpty()) {
+ unreadyQueue.add(request);
+ return;
+ }
+ // 2. unreadyQueue is empty, and new request is ready, so add it to the
+ // readyQueue directly.
+ if (request.getReadyFuture().isDone()) {
+ deque.add(request);
+ lock.notifyAll();
+ return;
+ }
+ // 3. unreadyQueue is empty, and new request isn't ready, so add it to the
+ // unreadyQueue, and register it as the first request.
+ unreadyQueue.add(request);
+ registerFirstRequestFuture(request, unreadyQueue);
+ });
+ }
+ }
+
+ private void registerFirstRequestFuture(
+ @Nonnull ChannelStateWriteRequest firstRequest,
+ Queue<ChannelStateWriteRequest> unreadyQueue) {
+ assert Thread.holdsLock(lock);
+ checkState(firstRequest == unreadyQueue.peek(), "The request isn't the first request.");
+
+ firstRequest
+ .getReadyFuture()
+ .thenAccept(
+ o -> {
+ synchronized (lock) {
+ moveReadyRequestToReadyQueue(unreadyQueue, firstRequest);
+ }
+ })
+ .exceptionally(
+ throwable -> {
+ // When dataFuture is completed, just move the request to readyQueue.
+ // And the throwable doesn't need to be handled here, it will be handled
+ // in channel state writer thread later.
+ synchronized (lock) {
+ moveReadyRequestToReadyQueue(unreadyQueue, firstRequest);
+ }
+ return null;
+ });
+ }
+
+ private void moveReadyRequestToReadyQueue(
+ Queue<ChannelStateWriteRequest> unreadyQueue, ChannelStateWriteRequest firstRequest) {
+ assert Thread.holdsLock(lock);
+ checkState(firstRequest == unreadyQueue.peek());
+ while (!unreadyQueue.isEmpty()) {
+ ChannelStateWriteRequest req = unreadyQueue.peek();
+ if (!req.getReadyFuture().isDone()) {
+ registerFirstRequestFuture(req, unreadyQueue);
+ return;
+ }
+ deque.add(Objects.requireNonNull(unreadyQueue.poll()));
+ lock.notifyAll();
+ }
}
@Override
public void submitPriority(ChannelStateWriteRequest request) throws Exception {
- submitInternal(request, () -> deque.addFirst(request));
+ synchronized (lock) {
+ SubtaskID subtaskID = SubtaskID.of(request.getJobVertexID(), request.getSubtaskIndex());
+ checkArgument(
+ unreadyQueues.containsKey(subtaskID),
+ "The subtask %s is not yet registered.", subtaskID);
+ checkState(request.getReadyFuture().isDone(), "The priority request must be ready.");
+ submitInternal(
+ request,
+ () -> {
+ deque.addFirst(request);
+ lock.notifyAll();
+ });
+ }
}
private void submitInternal(ChannelStateWriteRequest request, RunnableWithException action)
@@ -145,6 +320,7 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
}
private void ensureRunning() throws Exception {
+ assert Thread.holdsLock(lock);
// this check should be performed *at least after* enqueuing a request
// checking before is not enough because (check + enqueue) is not atomic
if (wasClosed || !thread.isAlive()) {
@@ -158,8 +334,63 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
}
@Override
- public void close() throws IOException {
- wasClosed = true;
+ public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+ SubtaskID subtaskID = SubtaskID.of(jobVertexID, subtaskIndex);
+ boolean completeRegister = false;
+ synchronized (lock) {
+ checkState(isRegistering(), "This executor has been registered.");
+ checkState(
+ !subtasks.contains(subtaskID),
+ String.format("This subtask[%s] has already registered.", subtaskID));
+ subtasks.add(subtaskID);
+ deque.add(
+ ChannelStateWriteRequest.registerSubtask(
+ subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex()));
+ lock.notifyAll();
+ unreadyQueues.put(subtaskID, new ArrayDeque<>());
+ if (subtasks.size() == maxSubtasksPerChannelStateFile) {
+ completeRegister = completeRegister();
+ }
+ }
+ if (completeRegister) {
+ onRegistered.accept(this);
+ }
+ }
+
+ @VisibleForTesting
+ public boolean isRegistering() {
+ synchronized (lock) {
+ return isRegistering;
+ }
+ }
+
+ private boolean completeRegister() {
+ assert Thread.holdsLock(lock);
+ if (isRegistering) {
+ isRegistering = false;
+ return true;
+ }
+ return false;
+ }
+
+ @Override
+ public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) throws IOException {
+ boolean completeRegister = false;
+ try {
+ synchronized (lock) {
+ completeRegister = completeRegister();
+ subtasks.remove(SubtaskID.of(jobVertexID, subtaskIndex));
+ if (!subtasks.isEmpty()) {
+ return;
+ }
+ wasClosed = true;
+ lock.notifyAll();
+ }
+ } finally {
+ if (completeRegister) {
+ onRegistered.accept(this);
+ }
+ }
while (thread.isAlive()) {
thread.interrupt();
try {
@@ -168,11 +399,15 @@ class ChannelStateWriteRequestExecutorImpl implements ChannelStateWriteRequestEx
if (!thread.isAlive()) {
Thread.currentThread().interrupt();
}
- LOG.debug(taskName + " interrupted while waiting for the writer thread to die", e);
+ LOG.debug(
+ "Channel state executor is interrupted while waiting for the writer thread to die",
+ e);
}
}
- if (thrown != null) {
- throw new IOException(thrown);
+ synchronized (lock) {
+ if (thrown != null) {
+ throw new IOException(thrown);
+ }
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
index b1091d2b2bf..9b80813cf81 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImpl.java
@@ -21,8 +21,9 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStateOutputStream;
-import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
+import org.apache.flink.runtime.state.CheckpointStorage;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.Preconditions;
@@ -36,6 +37,7 @@ import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.atomic.AtomicBoolean;
import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput;
import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput;
@@ -65,18 +67,37 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
1000; // includes max-concurrent-checkpoints + checkpoints to be aborted (scheduled via
// mailbox)
+ private final JobVertexID jobVertexID;
+
+ private final int subtaskIndex;
+
private final String taskName;
+
private final ChannelStateWriteRequestExecutor executor;
private final ConcurrentMap<Long, ChannelStateWriteResult> results;
private final int maxCheckpoints;
+ private final AtomicBoolean wasClosed = new AtomicBoolean(false);
+
/**
* Creates a {@link ChannelStateWriterImpl} with {@link #DEFAULT_MAX_CHECKPOINTS} as {@link
* #maxCheckpoints}.
*/
public ChannelStateWriterImpl(
- String taskName, int subtaskIndex, CheckpointStorageWorkerView streamFactoryResolver) {
- this(taskName, subtaskIndex, streamFactoryResolver, DEFAULT_MAX_CHECKPOINTS);
+ JobVertexID jobVertexID,
+ String taskName,
+ int subtaskIndex,
+ CheckpointStorage checkpointStorage,
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory,
+ int maxSubtasksPerChannelStateFile) {
+ this(
+ jobVertexID,
+ taskName,
+ subtaskIndex,
+ checkpointStorage,
+ DEFAULT_MAX_CHECKPOINTS,
+ channelStateExecutorFactory,
+ maxSubtasksPerChannelStateFile);
}
/**
@@ -84,34 +105,41 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
* {@link ChannelStateSerializer}, and a {@link ChannelStateWriteRequestExecutorImpl}.
*
* @param taskName
- * @param streamFactoryResolver a factory to obtain output stream factory for a given checkpoint
+ * @param checkpointStorage a factory to obtain output stream factory for a given checkpoint
* @param maxCheckpoints maximum number of checkpoints to be written currently or finished but
* not taken yet.
*/
ChannelStateWriterImpl(
+ JobVertexID jobVertexID,
String taskName,
int subtaskIndex,
- CheckpointStorageWorkerView streamFactoryResolver,
- int maxCheckpoints) {
+ CheckpointStorage checkpointStorage,
+ int maxCheckpoints,
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory,
+ int maxSubtasksPerChannelStateFile) {
this(
+ jobVertexID,
taskName,
+ subtaskIndex,
new ConcurrentHashMap<>(maxCheckpoints),
- new ChannelStateWriteRequestExecutorImpl(
- taskName,
- new ChannelStateWriteRequestDispatcherImpl(
- taskName,
- subtaskIndex,
- streamFactoryResolver,
- new ChannelStateSerializerImpl())),
+ channelStateExecutorFactory.getOrCreateExecutor(
+ jobVertexID,
+ subtaskIndex,
+ checkpointStorage,
+ maxSubtasksPerChannelStateFile),
maxCheckpoints);
}
ChannelStateWriterImpl(
+ JobVertexID jobVertexID,
String taskName,
+ int subtaskIndex,
ConcurrentMap<Long, ChannelStateWriteResult> results,
ChannelStateWriteRequestExecutor executor,
int maxCheckpoints) {
+ this.jobVertexID = jobVertexID;
this.taskName = taskName;
+ this.subtaskIndex = subtaskIndex;
this.results = results;
this.maxCheckpoints = maxCheckpoints;
this.executor = executor;
@@ -135,6 +163,8 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
maxCheckpoints));
enqueue(
new CheckpointStartRequest(
+ jobVertexID,
+ subtaskIndex,
checkpointId,
result,
checkpointOptions.getTargetLocation()),
@@ -158,7 +188,7 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
checkpointId,
info,
startSeqNum);
- enqueue(write(checkpointId, info, iterator), false);
+ enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, iterator), false);
}
@Override
@@ -171,7 +201,7 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
info,
startSeqNum,
data == null ? 0 : data.length);
- enqueue(write(checkpointId, info, data), false);
+ enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, data), false);
}
@Override
@@ -187,27 +217,29 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
checkpointId,
info,
startSeqNum);
- enqueue(write(checkpointId, info, dataFuture), false);
+ enqueue(write(jobVertexID, subtaskIndex, checkpointId, info, dataFuture), false);
}
@Override
public void finishInput(long checkpointId) {
LOG.debug("{} finishing input data, checkpoint {}", taskName, checkpointId);
- enqueue(completeInput(checkpointId), false);
+ enqueue(completeInput(jobVertexID, subtaskIndex, checkpointId), false);
}
@Override
public void finishOutput(long checkpointId) {
LOG.debug("{} finishing output data, checkpoint {}", taskName, checkpointId);
- enqueue(completeOutput(checkpointId), false);
+ enqueue(completeOutput(jobVertexID, subtaskIndex, checkpointId), false);
}
@Override
public void abort(long checkpointId, Throwable cause, boolean cleanup) {
LOG.debug("{} aborting, checkpoint {}", taskName, checkpointId);
- enqueue(ChannelStateWriteRequest.abort(checkpointId, cause), true); // abort already started
enqueue(
- ChannelStateWriteRequest.abort(checkpointId, cause),
+ ChannelStateWriteRequest.abort(jobVertexID, subtaskIndex, checkpointId, cause),
+ true); // abort already started
+ enqueue(
+ ChannelStateWriteRequest.abort(jobVertexID, subtaskIndex, checkpointId, cause),
false); // abort enqueued but not started
if (cleanup) {
results.remove(checkpointId);
@@ -230,15 +262,14 @@ public class ChannelStateWriterImpl implements ChannelStateWriter {
return results.get(checkpointId);
}
- public void open() {
- executor.start();
- }
-
@Override
public void close() throws IOException {
- LOG.debug("close, dropping checkpoints {}", results.keySet());
- results.clear();
- executor.close();
+ if (wasClosed.compareAndSet(false, true)) {
+ LOG.debug("close, dropping checkpoints {}", results.keySet());
+ results.clear();
+ enqueue(ChannelStateWriteRequest.releaseSubtask(jobVertexID, subtaskIndex), false);
+ executor.releaseSubtask(jobVertexID, subtaskIndex);
+ }
}
private void enqueue(ChannelStateWriteRequest request, boolean atTheFront) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
index 40a587fd5c3..bd8d366f29a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -261,4 +262,6 @@ public interface Environment {
default CheckpointStorageAccess getCheckpointStorageAccess() {
throw new UnsupportedOperationException();
}
+
+ ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory();
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java
index 725182bebbf..204f3d3c074 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/logger/NetworkActionsLogger.java
@@ -97,11 +97,10 @@ public class NetworkActionsLogger {
}
public static void tracePersist(
- String action, Buffer buffer, String taskName, Object channelInfo, long checkpointId) {
+ String action, Buffer buffer, Object channelInfo, long checkpointId) {
if (LOG.isTraceEnabled()) {
LOG.trace(
- "[{}] {} {}, checkpoint {} @ {}",
- taskName,
+ "{} {}, checkpoint {} @ {}",
action,
buffer.toDebugString(INCLUDE_HASH),
checkpointId,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManager.java
new file mode 100644
index 00000000000..db5ff6d416b
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManager.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import javax.annotation.concurrent.GuardedBy;
+import javax.annotation.concurrent.ThreadSafe;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * This class holds all the {@link ChannelStateWriteRequestExecutorFactory} objects for a task
+ * executor (manager).
+ */
+@ThreadSafe
+public class TaskExecutorChannelStateExecutorFactoryManager {
+
+ /** Logger for this class. */
+ private static final Logger LOG =
+ LoggerFactory.getLogger(TaskExecutorChannelStateExecutorFactoryManager.class);
+
+ private final Object lock = new Object();
+
+ @GuardedBy("lock")
+ private final Map<JobID, ChannelStateWriteRequestExecutorFactory> executorFactoryByJobId;
+
+ @GuardedBy("lock")
+ private boolean closed;
+
+ public TaskExecutorChannelStateExecutorFactoryManager() {
+ this.executorFactoryByJobId = new HashMap<>();
+ this.closed = false;
+ }
+
+ public ChannelStateWriteRequestExecutorFactory getOrCreateExecutorFactory(
+ @Nonnull JobID jobID) {
+ synchronized (lock) {
+ if (closed) {
+ throw new IllegalStateException(
+ "TaskExecutorChannelStateExecutorFactoryManager is already closed and cannot "
+ + "get a new executor factory.");
+ }
+ ChannelStateWriteRequestExecutorFactory factory = executorFactoryByJobId.get(jobID);
+ if (factory == null) {
+ LOG.info("Creating the channel state executor factory for job id {}", jobID);
+ factory = new ChannelStateWriteRequestExecutorFactory(jobID);
+ executorFactoryByJobId.put(jobID, factory);
+ }
+ return factory;
+ }
+ }
+
+ public void releaseResourcesForJob(@Nonnull JobID jobID) {
+ LOG.debug("Releasing the factory under job id {}", jobID);
+ synchronized (lock) {
+ if (closed) {
+ return;
+ }
+ executorFactoryByJobId.remove(jobID);
+ }
+ }
+
+ public void shutdown() {
+ synchronized (lock) {
+ if (closed) {
+ return;
+ }
+ closed = true;
+ executorFactoryByJobId.clear();
+ LOG.info("Shutting down TaskExecutorChannelStateExecutorFactoryManager.");
+ }
+ }
+
+ @VisibleForTesting
+ @Nullable
+ public ChannelStateWriteRequestExecutorFactory getFactoryByJobId(JobID jobId) {
+ synchronized (lock) {
+ return executorFactoryByJobId.get(jobId);
+ }
+ }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
index 9c3dd3744af..e7226b9abcd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskExecutor.java
@@ -97,6 +97,7 @@ import org.apache.flink.runtime.rpc.RpcServiceUtils;
import org.apache.flink.runtime.security.token.DelegationTokenReceiverRepository;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.state.TaskExecutorChannelStateExecutorFactoryManager;
import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
import org.apache.flink.runtime.state.TaskLocalStateStore;
@@ -217,6 +218,12 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
/** The changelog manager for this task, providing changelog storage per job. */
private final TaskExecutorStateChangelogStoragesManager changelogStoragesManager;
+ /**
+ * The channel state executor factory manager for this task executor, providing a channel
+ * state executor factory per job.
+ */
+ private final TaskExecutorChannelStateExecutorFactoryManager channelStateExecutorFactoryManager;
+
/** Information provider for external resources. */
private final ExternalResourceInfoProvider externalResourceInfoProvider;
@@ -319,6 +326,8 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
taskExecutorServices.getUnresolvedTaskManagerLocation();
this.localStateStoresManager = taskExecutorServices.getTaskManagerStateStore();
this.changelogStoragesManager = taskExecutorServices.getTaskManagerChangelogManager();
+ this.channelStateExecutorFactoryManager =
+ taskExecutorServices.getTaskManagerChannelStateManager();
this.shuffleEnvironment = taskExecutorServices.getShuffleEnvironment();
this.kvStateService = taskExecutorServices.getKvStateService();
this.ioExecutor = taskExecutorServices.getIOExecutor();
@@ -479,6 +488,7 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
}
changelogStoragesManager.shutdown();
+ channelStateExecutorFactoryManager.shutdown();
Preconditions.checkState(jobTable.isEmpty());
@@ -771,7 +781,8 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
taskManagerConfiguration,
taskMetricGroup,
partitionStateChecker,
- getRpcService().getScheduledExecutor());
+ getRpcService().getScheduledExecutor(),
+ channelStateExecutorFactoryManager.getOrCreateExecutorFactory(jobId));
taskMetricGroup.gauge(MetricNames.IS_BACK_PRESSURED, task::isBackPressured);
@@ -1871,6 +1882,7 @@ public class TaskExecutor extends RpcEndpoint implements TaskExecutorGateway {
taskManagerMetricGroup.removeJobMetricsGroup(jobId);
changelogStoragesManager.releaseResourcesForJob(jobId);
currentSlotOfferPerJob.remove(jobId);
+ channelStateExecutorFactoryManager.releaseResourcesForJob(jobId);
}
private void scheduleResultPartitionCleanup(JobID jobId) {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java
index d439146f263..bf4e11486b5 100755
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskexecutor/TaskManagerServices.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.rpc.FatalErrorHandler;
import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
import org.apache.flink.runtime.shuffle.ShuffleServiceLoader;
+import org.apache.flink.runtime.state.TaskExecutorChannelStateExecutorFactoryManager;
import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
import org.apache.flink.runtime.taskexecutor.slot.DefaultTimerService;
@@ -80,6 +81,7 @@ public class TaskManagerServices {
private final JobLeaderService jobLeaderService;
private final TaskExecutorLocalStateStoresManager taskManagerStateStore;
private final TaskExecutorStateChangelogStoragesManager taskManagerChangelogManager;
+ private final TaskExecutorChannelStateExecutorFactoryManager taskManagerChannelStateManager;
private final TaskEventDispatcher taskEventDispatcher;
private final ExecutorService ioExecutor;
private final LibraryCacheManager libraryCacheManager;
@@ -98,6 +100,7 @@ public class TaskManagerServices {
JobLeaderService jobLeaderService,
TaskExecutorLocalStateStoresManager taskManagerStateStore,
TaskExecutorStateChangelogStoragesManager taskManagerChangelogManager,
+ TaskExecutorChannelStateExecutorFactoryManager taskManagerChannelStateManager,
TaskEventDispatcher taskEventDispatcher,
ExecutorService ioExecutor,
LibraryCacheManager libraryCacheManager,
@@ -116,6 +119,7 @@ public class TaskManagerServices {
this.jobLeaderService = Preconditions.checkNotNull(jobLeaderService);
this.taskManagerStateStore = Preconditions.checkNotNull(taskManagerStateStore);
this.taskManagerChangelogManager = Preconditions.checkNotNull(taskManagerChangelogManager);
+ this.taskManagerChannelStateManager = taskManagerChannelStateManager;
this.taskEventDispatcher = Preconditions.checkNotNull(taskEventDispatcher);
this.ioExecutor = Preconditions.checkNotNull(ioExecutor);
this.libraryCacheManager = Preconditions.checkNotNull(libraryCacheManager);
@@ -171,6 +175,10 @@ public class TaskManagerServices {
return taskManagerChangelogManager;
}
+ public TaskExecutorChannelStateExecutorFactoryManager getTaskManagerChannelStateManager() {
+ return taskManagerChannelStateManager;
+ }
+
public TaskEventDispatcher getTaskEventDispatcher() {
return taskEventDispatcher;
}
@@ -341,6 +349,9 @@ public class TaskManagerServices {
final TaskExecutorStateChangelogStoragesManager changelogStoragesManager =
new TaskExecutorStateChangelogStoragesManager();
+ final TaskExecutorChannelStateExecutorFactoryManager channelStateExecutorFactoryManager =
+ new TaskExecutorChannelStateExecutorFactoryManager();
+
final boolean failOnJvmMetaspaceOomError =
taskManagerServicesConfiguration
.getConfiguration()
@@ -382,6 +393,7 @@ public class TaskManagerServices {
jobLeaderService,
taskStateManager,
changelogStoragesManager,
+ channelStateExecutorFactoryManager,
taskEventDispatcher,
ioExecutor,
libraryCacheManager,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index 74ce5e3fd32..e1d28ab359c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -106,6 +107,8 @@ public class RuntimeEnvironment implements Environment {
@Nullable private CheckpointStorageAccess checkpointStorageAccess;
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
// ------------------------------------------------------------------------
public RuntimeEnvironment(
@@ -135,7 +138,8 @@ public class RuntimeEnvironment implements Environment {
TaskManagerRuntimeInfo taskManagerInfo,
TaskMetricGroup metrics,
Task containingTask,
- ExternalResourceInfoProvider externalResourceInfoProvider) {
+ ExternalResourceInfoProvider externalResourceInfoProvider,
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
this.jobId = checkNotNull(jobId);
this.jobVertexId = checkNotNull(jobVertexId);
@@ -164,6 +168,7 @@ public class RuntimeEnvironment implements Environment {
this.containingTask = containingTask;
this.metrics = metrics;
this.externalResourceInfoProvider = checkNotNull(externalResourceInfoProvider);
+ this.channelStateExecutorFactory = checkNotNull(channelStateExecutorFactory);
}
// ------------------------------------------------------------------------
@@ -368,4 +373,9 @@ public class RuntimeEnvironment implements Environment {
return checkNotNull(
checkpointStorageAccess, "checkpointStorage has not been initialized yet!");
}
+
+ @Override
+ public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+ return channelStateExecutorFactory;
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 059534f346e..acdf1b00d56 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -37,6 +37,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointStoreUtil;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -262,6 +263,9 @@ public class Task
/** Future that is completed once {@link #run()} exits. */
private final CompletableFuture<ExecutionState> terminationFuture = new CompletableFuture<>();
+ /** The factory of channel state write request executor. */
+ private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
// ------------------------------------------------------------------------
// Fields that control the task execution. All these fields are volatile
// (which means that they introduce memory barriers), to establish
@@ -324,7 +328,8 @@ public class Task
TaskManagerRuntimeInfo taskManagerConfig,
@Nonnull TaskMetricGroup metricGroup,
PartitionProducerStateChecker partitionProducerStateChecker,
- Executor executor) {
+ Executor executor,
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
Preconditions.checkNotNull(jobInformation);
Preconditions.checkNotNull(taskInformation);
@@ -382,6 +387,7 @@ public class Task
this.partitionProducerStateChecker =
Preconditions.checkNotNull(partitionProducerStateChecker);
this.executor = Preconditions.checkNotNull(executor);
+ this.channelStateExecutorFactory = channelStateExecutorFactory;
// create the reader and writer structures
@@ -708,7 +714,8 @@ public class Task
taskManagerConfig,
metrics,
this,
- externalResourceInfoProvider);
+ externalResourceInfoProvider,
+ channelStateExecutorFactory);
// Make sure the user code classloader is accessible thread-locally.
// We are setting the correct context class loader before instantiating the invokable
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
index d135cbdd500..c2c3efd4b34 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateCheckpointWriterTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelSta
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
@@ -34,28 +35,42 @@ import org.apache.flink.util.function.RunnableWithException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;
+import javax.annotation.Nullable;
+
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.IOException;
import java.nio.file.Path;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.Iterator;
import java.util.Map;
import java.util.Random;
+import java.util.Set;
import java.util.stream.IntStream;
import static java.util.Collections.singletonList;
import static org.apache.flink.core.fs.Path.fromLocalFile;
import static org.apache.flink.core.fs.local.LocalFileSystem.getSharedInstance;
import static org.apache.flink.core.memory.MemorySegmentFactory.wrap;
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertAllSubtaskDoneNormally;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertAllSubtaskNotDone;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertCheckpointFailureReason;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertHasSpecialCause;
import static org.apache.flink.runtime.state.CheckpointedStateScope.EXCLUSIVE;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.assertj.core.api.Assertions.fail;
/** {@link ChannelStateCheckpointWriter} test. */
class ChannelStateCheckpointWriterTest {
private static final RunnableWithException NO_OP_RUNNABLE = () -> {};
private final Random random = new Random();
+ private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+ private static final int SUBTASK_INDEX = 0;
+ private static final SubtaskID SUBTASK_ID = SubtaskID.of(JOB_VERTEX_ID, SUBTASK_INDEX);
@TempDir private Path temporaryFolder;
@@ -89,8 +104,8 @@ class ChannelStateCheckpointWriterTest {
write(writer, channels[channel], getData(numBytesPerWrite));
}
}
- writer.completeInput();
- writer.completeOutput();
+ writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
+ writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
assertThat(handle.getStateSize())
@@ -120,9 +135,9 @@ class ChannelStateCheckpointWriterTest {
new NetworkBuffer(
MemorySegmentFactory.allocateUnpooledSegment(threshold / 2),
FreeingBufferRecycler.INSTANCE);
- writer.writeInput(new InputChannelInfo(1, 2), buffer);
- writer.completeOutput();
- writer.completeInput();
+ writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, new InputChannelInfo(1, 2), buffer);
+ writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
+ writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
assertThat(result.isDone()).isTrue();
assertThat(checkpointsDir).isEmptyDirectory();
assertThat(sharedStateDir).isEmptyDirectory();
@@ -139,8 +154,8 @@ class ChannelStateCheckpointWriterTest {
}
};
ChannelStateCheckpointWriter writer = createWriter(new ChannelStateWriteResult(), stream);
- writer.completeOutput();
- writer.completeInput();
+ writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
+ writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
assertThat(stream.isClosed()).isTrue();
}
@@ -149,9 +164,9 @@ class ChannelStateCheckpointWriterTest {
ChannelStateCheckpointWriter writer = createWriter(new ChannelStateWriteResult());
NetworkBuffer buffer =
new NetworkBuffer(
- MemorySegmentFactory.allocateUnpooledSegment(10, null),
+ MemorySegmentFactory.allocateUnpooledSegment(10),
FreeingBufferRecycler.INSTANCE);
- writer.writeInput(new InputChannelInfo(1, 2), buffer);
+ writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, new InputChannelInfo(1, 2), buffer);
assertThat(buffer.isRecycled()).isTrue();
}
@@ -174,29 +189,194 @@ class ChannelStateCheckpointWriterTest {
FlushRecorder dataStream = new FlushRecorder();
final ChannelStateCheckpointWriter writer =
new ChannelStateCheckpointWriter(
- "dummy task",
- 0,
+ Collections.singleton(SUBTASK_ID),
1L,
- new ChannelStateWriteResult(),
new ChannelStateSerializerImpl(),
NO_OP_RUNNABLE,
new MemoryCheckpointOutputStream(42),
dataStream);
-
- writer.completeInput();
- writer.completeOutput();
+ writer.registerSubtaskResult(SUBTASK_ID, new ChannelStateWriteResult());
+ writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
+ writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
assertThat(dataStream.flushed).isTrue();
}
@Test
void testResultCompletion() throws Exception {
- ChannelStateWriteResult result = new ChannelStateWriteResult();
- ChannelStateCheckpointWriter writer = createWriter(result);
- writer.completeInput();
- assertThat(result.isDone()).isFalse();
- writer.completeOutput();
- assertThat(result.isDone()).isTrue();
+ for (int maxSubtasksPerChannelStateFile = 1;
+ maxSubtasksPerChannelStateFile < 10;
+ maxSubtasksPerChannelStateFile++) {
+ testMultiTaskCompletionAndAssertResult(maxSubtasksPerChannelStateFile);
+ }
+ }
+
+ /**
+ * Registers {@code maxSubtasksPerChannelStateFile} subtasks on one writer, completes the input
+ * and the output of each subtask in turn, and asserts that no result is done and the stream
+ * stays open until the last subtask has completed its output.
+ */
+ private void testMultiTaskCompletionAndAssertResult(int maxSubtasksPerChannelStateFile)
+ throws Exception {
+ Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+ for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+ subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+ }
+ MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+ ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+ for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+ writer.registerSubtaskResult(entry.getKey(), entry.getValue());
+ }
+
+ for (SubtaskID subtaskID : subtasks.keySet()) {
+ // The checks run before each completion call, so in the last iteration they verify
+ // that nothing was published while one subtask was still outstanding.
+ assertAllSubtaskNotDone(subtasks.values());
+ assertThat(stream.isClosed()).isFalse();
+ writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+ assertAllSubtaskNotDone(subtasks.values());
+ assertThat(stream.isClosed()).isFalse();
+ writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+ }
+ // Only after the last completeOutput is the stream closed and every result done.
+ assertThat(stream.isClosed()).isTrue();
+ assertAllSubtaskDoneNormally(subtasks.values());
+ }
+
+ /** Runs the unregister scenario for several subtask counts. */
+ @Test
+ void testTaskUnregister() throws Exception {
+ for (int subtaskCount : new int[] {2, 3, 5, 10}) {
+ testTaskUnregisterAndAssertResult(subtaskCount);
+ }
+ }
+
+ /**
+ * Creates a writer for {@code maxSubtasksPerChannelStateFile} subtasks but registers results
+ * for all of them except one. After every registered subtask has completed its input and
+ * output, the stream must still be open and no result done; both only happen once the
+ * remaining subtask is released.
+ */
+ private void testTaskUnregisterAndAssertResult(int maxSubtasksPerChannelStateFile)
+ throws Exception {
+ Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+ for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+ subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+ }
+ MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+ // NOTE(review): keySet() is a live view that shrinks when an entry is removed below;
+ // presumably the writer copies the set on construction - confirm.
+ ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+ // Keep the first entry unregistered and drop it from the map, so that the map afterwards
+ // holds exactly the subtasks whose results were handed to the writer.
+ SubtaskID unregisterSubtask = null;
+ Iterator<Map.Entry<SubtaskID, ChannelStateWriteResult>> iterator =
+ subtasks.entrySet().iterator();
+ while (iterator.hasNext()) {
+ Map.Entry<SubtaskID, ChannelStateWriteResult> entry = iterator.next();
+ if (unregisterSubtask == null) {
+ unregisterSubtask = entry.getKey();
+ iterator.remove();
+ continue;
+ }
+ writer.registerSubtaskResult(entry.getKey(), entry.getValue());
+ }
+
+ for (SubtaskID subtaskID : subtasks.keySet()) {
+ writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+ writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+ }
+ assertAllSubtaskNotDone(subtasks.values());
+ assertThat(stream.isClosed()).isFalse();
+
+ // Checked through AssertJ because a plain Java assert is skipped unless the JVM runs
+ // with -ea.
+ assertThat(unregisterSubtask).isNotNull();
+ writer.releaseSubtask(unregisterSubtask);
+ assertThat(stream.isClosed()).isTrue();
+ assertAllSubtaskDoneNormally(subtasks.values());
+ }
+
+ /** Runs the fail-after-registration scenario for several subtask counts. */
+ @Test
+ void testTaskFailThenCompleteOtherTask() {
+ for (int subtaskCount : new int[] {2, 3, 5, 10}) {
+ testTaskFailAfterAllTaskRegisteredAndAssertResult(subtaskCount);
+ }
+ }
+
+ /**
+ * Registers {@code maxSubtasksPerChannelStateFile} subtasks on one writer and then fails the
+ * first of them with a {@link TestException}. The stream must be closed by the failure, the
+ * failed subtask's result must carry the {@link TestException}, and every other result must
+ * report {@code CHANNEL_STATE_SHARED_STREAM_EXCEPTION}.
+ */
+ private void testTaskFailAfterAllTaskRegisteredAndAssertResult(
+ int maxSubtasksPerChannelStateFile) {
+ Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+ for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+ subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+ }
+ MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+ ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+ // Remember the first subtask in iteration order; it is the one that fails below.
+ SubtaskID firstSubtask = null;
+ for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+ if (firstSubtask == null) {
+ firstSubtask = entry.getKey();
+ }
+ writer.registerSubtaskResult(entry.getKey(), entry.getValue());
+ }
+ assertThat(stream.isClosed()).isFalse();
+
+ // Checked through AssertJ because a plain Java assert is skipped unless the JVM runs
+ // with -ea.
+ assertThat(firstSubtask).isNotNull();
+ writer.fail(
+ firstSubtask.getJobVertexID(), firstSubtask.getSubtaskIndex(), new TestException());
+ assertThat(stream.isClosed()).isTrue();
+
+ for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+ if (firstSubtask.equals(entry.getKey())) {
+ assertHasSpecialCause(entry.getValue(), TestException.class);
+ continue;
+ }
+ assertCheckpointFailureReason(entry.getValue(), CHANNEL_STATE_SHARED_STREAM_EXCEPTION);
+ }
+ }
+
+ /**
+ * Uses a stream whose {@code closeAndGetHandle()} throws and verifies that, once all five
+ * subtasks have completed their input and output, the stream reports closed and both result
+ * futures of every subtask fail with the stream's {@link IOException} as cause.
+ */
+ @Test
+ void testCloseGetHandleThrowException() throws Exception {
+ Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+ for (int i = 0; i < 5; i++) {
+ subtasks.put(SubtaskID.of(new JobVertexID(), i), new ChannelStateWriteResult());
+ }
+ CloseExceptionOutputStream stream = new CloseExceptionOutputStream();
+ ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+ for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+ SubtaskID subtaskID = entry.getKey();
+ writer.registerSubtaskResult(subtaskID, entry.getValue());
+ // Every subtask writes one input buffer; presumably this makes sure the writer has
+ // data and therefore asks the stream for a handle - confirm.
+ NetworkBuffer buffer =
+ new NetworkBuffer(
+ MemorySegmentFactory.allocateUnpooledSegment(10),
+ FreeingBufferRecycler.INSTANCE);
+ writer.writeInput(
+ subtaskID.getJobVertexID(),
+ subtaskID.getSubtaskIndex(),
+ new InputChannelInfo(1, 2),
+ buffer);
+ }
+
+ for (SubtaskID subtaskID : subtasks.keySet()) {
+ // Nothing may be finished while at least one subtask is still outstanding.
+ assertAllSubtaskNotDone(subtasks.values());
+ assertThat(stream.isClosed()).isFalse();
+ writer.completeInput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+ assertAllSubtaskNotDone(subtasks.values());
+ assertThat(stream.isClosed()).isFalse();
+ writer.completeOutput(subtaskID.getJobVertexID(), subtaskID.getSubtaskIndex());
+ }
+ assertThat(stream.isClosed()).isTrue();
+ // The failure of the shared stream must surface on both futures of every subtask.
+ for (Map.Entry<SubtaskID, ChannelStateWriteResult> entry : subtasks.entrySet()) {
+ assertThatThrownBy(() -> entry.getValue().getInputChannelStateHandles().get())
+ .cause()
+ .isInstanceOf(IOException.class)
+ .hasMessage("Test closeAndGetHandle exception.");
+ assertThatThrownBy(() -> entry.getValue().getResultSubpartitionStateHandles().get())
+ .cause()
+ .isInstanceOf(IOException.class)
+ .hasMessage("Test closeAndGetHandle exception.");
+ }
+ }
+
+ /**
+ * Verifies that registering a subtask result after the writer has been failed is rejected
+ * with an {@link IllegalStateException}.
+ */
+ @Test
+ void testRegisterSubtaskAfterWriterDone() {
+ Map<SubtaskID, ChannelStateWriteResult> subtasks = new HashMap<>();
+ SubtaskID subtask0 = SubtaskID.of(JOB_VERTEX_ID, 0);
+ SubtaskID subtask1 = SubtaskID.of(JOB_VERTEX_ID, 1);
+ subtasks.put(subtask0, new ChannelStateWriteResult());
+ subtasks.put(subtask1, new ChannelStateWriteResult());
+ MemoryCheckpointOutputStream stream = new MemoryCheckpointOutputStream(1000);
+ ChannelStateCheckpointWriter writer = createWriter(stream, subtasks.keySet());
+ // NOTE(review): the failure is reported for a fresh JobVertexID, not for JOB_VERTEX_ID,
+ // so it matches neither subtask0 nor subtask1 - presumably intentional; confirm.
+ writer.fail(new JobVertexID(), 0, new TestException());
+ // Both subtasks the writer was created for must now be refused.
+ assertThatThrownBy(
+ () -> writer.registerSubtaskResult(subtask0, new ChannelStateWriteResult()))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("The write is done.");
+ assertThatThrownBy(
+ () -> writer.registerSubtaskResult(subtask1, new ChannelStateWriteResult()))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("The write is done.");
}
@Test
@@ -214,8 +394,8 @@ class ChannelStateCheckpointWriterTest {
write(writer, e.getKey(), getData(numBytes));
}
}
- writer.completeInput();
- writer.completeOutput();
+ writer.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX);
+ writer.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX);
for (InputChannelStateHandle handle : result.inputChannelStateHandles.get()) {
int headerSize = Integer.BYTES;
@@ -245,7 +425,7 @@ class ChannelStateCheckpointWriterTest {
FreeingBufferRecycler.INSTANCE,
Buffer.DataType.DATA_BUFFER,
segment.size());
- writer.writeInput(channelInfo, buffer);
+ writer.writeInput(JOB_VERTEX_ID, SUBTASK_INDEX, channelInfo, buffer);
}
private ChannelStateCheckpointWriter createWriter(ChannelStateWriteResult result) {
@@ -254,13 +434,28 @@ class ChannelStateCheckpointWriterTest {
private ChannelStateCheckpointWriter createWriter(
ChannelStateWriteResult result, CheckpointStateOutputStream stream) {
+ ChannelStateCheckpointWriter writer =
+ createWriter(stream, Collections.singleton(SUBTASK_ID));
+ writer.registerSubtaskResult(SUBTASK_ID, result);
+ return writer;
+ }
+
+ private ChannelStateCheckpointWriter createWriter(
+ CheckpointStateOutputStream stream, Set<SubtaskID> subtasks) {
return new ChannelStateCheckpointWriter(
- "dummy task",
- 0,
- 1L,
- result,
- stream,
- new ChannelStateSerializerImpl(),
- NO_OP_RUNNABLE);
+ subtasks, 1L, stream, new ChannelStateSerializerImpl(), NO_OP_RUNNABLE);
+ }
+}
+
+/** An output stream whose {@code closeAndGetHandle()} always fails with an {@link IOException}. */
+class CloseExceptionOutputStream extends MemoryCheckpointOutputStream {
+ public CloseExceptionOutputStream() {
+ super(1000);
+ }
+
+ @Nullable
+ @Override
+ public StreamStateHandle closeAndGetHandle() throws IOException {
+ throw new IOException("Test closeAndGetHandle exception.");
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
index 3d80f290856..5ece5423a39 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
@@ -19,29 +19,42 @@ package org.apache.flink.runtime.checkpoint.channel;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
+import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
-import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAccess;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.junit.jupiter.api.Test;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
import java.util.function.Function;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHANNEL_STATE_SHARED_STREAM_EXCEPTION;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertCheckpointFailureReason;
+import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteResultUtil.assertHasSpecialCause;
import static org.apache.flink.util.CloseableIterator.ofElements;
import static org.assertj.core.api.Assertions.assertThat;
/** {@link ChannelStateWriteRequestDispatcherImpl} test. */
class ChannelStateWriteRequestDispatcherImplTest {
+ private static final JobID JOB_ID = new JobID();
+ private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+ private static final int SUBTASK_INDEX = 0;
+
@Test
void testPartialInputChannelStateWrite() throws Exception {
testBuffersRecycled(
buffers ->
ChannelStateWriteRequest.write(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
1L,
new InputChannelInfo(1, 2),
ofElements(Buffer::recycleBuffer, buffers)));
@@ -52,52 +65,167 @@ class ChannelStateWriteRequestDispatcherImplTest {
testBuffersRecycled(
buffers ->
ChannelStateWriteRequest.write(
- 1L, new ResultSubpartitionInfo(1, 2), buffers));
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
+ 1L,
+ new ResultSubpartitionInfo(1, 2),
+ buffers));
+ }
+
+ /**
+ * Registers the subtask, starts checkpoint 1, fails both result futures, and then dispatches
+ * the write request built by {@code requestBuilder} from two buffers. Both buffers must be
+ * recycled after the dispatch.
+ *
+ * @param requestBuilder turns the test buffers into the write request to dispatch
+ */
+ private void testBuffersRecycled(
+ Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
+ ChannelStateWriteRequestDispatcher dispatcher =
+ new ChannelStateWriteRequestDispatcherImpl(
+ new JobManagerCheckpointStorage(),
+ JOB_ID,
+ new ChannelStateSerializerImpl());
+ ChannelStateWriteResult result = new ChannelStateWriteResult();
+ dispatcher.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
+ dispatcher.dispatch(
+ ChannelStateWriteRequest.start(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
+ 1L,
+ result,
+ CheckpointStorageLocationReference.getDefault()));
+
+ // Fail the result before the write request arrives.
+ result.getResultSubpartitionStateHandles().completeExceptionally(new TestException());
+ result.getInputChannelStateHandles().completeExceptionally(new TestException());
+
+ NetworkBuffer[] buffers = new NetworkBuffer[] {buffer(), buffer()};
+ dispatcher.dispatch(requestBuilder.apply(buffers));
+ for (NetworkBuffer buffer : buffers) {
+ assertThat(buffer.isRecycled()).isTrue();
+ }
}
@Test
- void testConcurrentUnalignedCheckpoint() throws Exception {
+ void testStartNewCheckpointForSameSubtask() throws Exception {
+ testStartNewCheckpointAndCheckOldCheckpointResult(false);
+ }
+
+ @Test
+ void testStartNewCheckpointForDifferentSubtask() throws Exception {
+ testStartNewCheckpointAndCheckOldCheckpointResult(true);
+ }
+
+ private void testStartNewCheckpointAndCheckOldCheckpointResult(boolean isDifferentSubtask)
+ throws Exception {
ChannelStateWriteRequestDispatcher processor =
new ChannelStateWriteRequestDispatcherImpl(
- "dummy task",
- 0,
- getStreamFactoryFactory(),
+ new JobManagerCheckpointStorage(),
+ JOB_ID,
new ChannelStateSerializerImpl());
ChannelStateWriteResult result = new ChannelStateWriteResult();
+ processor.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
+ JobVertexID newJobVertex = JOB_VERTEX_ID;
+ if (isDifferentSubtask) {
+ newJobVertex = new JobVertexID();
+ processor.dispatch(
+ ChannelStateWriteRequest.registerSubtask(newJobVertex, SUBTASK_INDEX));
+ }
processor.dispatch(
ChannelStateWriteRequest.start(
- 1L, result, CheckpointStorageLocationReference.getDefault()));
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
+ 1L,
+ result,
+ CheckpointStorageLocationReference.getDefault()));
assertThat(result.isDone()).isFalse();
processor.dispatch(
ChannelStateWriteRequest.start(
+ newJobVertex,
+ SUBTASK_INDEX,
2L,
new ChannelStateWriteResult(),
CheckpointStorageLocationReference.getDefault()));
- assertThat(result.isDone()).isTrue();
- assertThat(result.getInputChannelStateHandles()).isCompletedExceptionally();
- assertThat(result.getResultSubpartitionStateHandles()).isCompletedExceptionally();
+ assertCheckpointFailureReason(result, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
}
- private void testBuffersRecycled(
- Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
- ChannelStateWriteRequestDispatcher dispatcher =
+ @Test
+ void testStartOldCheckpointForSameSubtask() throws Exception {
+ testStartOldCheckpointAfterNewCheckpointAborted(false);
+ }
+
+ @Test
+ void testStartOldCheckpointForDifferentSubtask() throws Exception {
+ testStartOldCheckpointAfterNewCheckpointAborted(true);
+ }
+
+ private void testStartOldCheckpointAfterNewCheckpointAborted(boolean isDifferentSubtask)
+ throws Exception {
+ ChannelStateWriteRequestDispatcher processor =
new ChannelStateWriteRequestDispatcherImpl(
- "dummy task",
- 0,
- new MemoryBackendCheckpointStorageAccess(new JobID(), null, null, 1),
+ new JobManagerCheckpointStorage(),
+ JOB_ID,
new ChannelStateSerializerImpl());
+ processor.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX));
+ JobVertexID newJobVertex = JOB_VERTEX_ID;
+ if (isDifferentSubtask) {
+ newJobVertex = new JobVertexID();
+ processor.dispatch(
+ ChannelStateWriteRequest.registerSubtask(newJobVertex, SUBTASK_INDEX));
+ }
+ processor.dispatch(
+ ChannelStateWriteRequest.abort(
+ JOB_VERTEX_ID, SUBTASK_INDEX, 2L, new TestException()));
+
ChannelStateWriteResult result = new ChannelStateWriteResult();
- dispatcher.dispatch(
+ processor.dispatch(
ChannelStateWriteRequest.start(
- 1L, result, CheckpointStorageLocationReference.getDefault()));
+ newJobVertex,
+ SUBTASK_INDEX,
+ 1L,
+ result,
+ CheckpointStorageLocationReference.getDefault()));
+ assertCheckpointFailureReason(result, CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED);
+ }
- result.getResultSubpartitionStateHandles().completeExceptionally(new TestException());
- result.getInputChannelStateHandles().completeExceptionally(new TestException());
+ /** Runs the abort-before-start scenario for several subtask counts. */
+ @Test
+ void testAbortCheckpointAndCheckAllException() throws Exception {
+ for (int subtaskCount : new int[] {1, 2, 3, 5, 10}) {
+ testAbortCheckpointAndCheckAllException(subtaskCount);
+ }
+ }
- NetworkBuffer[] buffers = new NetworkBuffer[] {buffer(), buffer()};
- dispatcher.dispatch(requestBuilder.apply(buffers));
- assertThat(buffers).allMatch(NetworkBuffer::isRecycled);
+ /**
+ * Registers {@code numberOfSubtask} subtasks of one job vertex, aborts checkpoint 1 for a
+ * randomly chosen subtask before any subtask has started it, and then starts the checkpoint
+ * for every subtask. All results must be done: the aborted subtask with the {@link
+ * TestException}, every other subtask with {@code CHANNEL_STATE_SHARED_STREAM_EXCEPTION}.
+ */
+ private void testAbortCheckpointAndCheckAllException(int numberOfSubtask) throws Exception {
+ ChannelStateWriteRequestDispatcher processor =
+ new ChannelStateWriteRequestDispatcherImpl(
+ new JobManagerCheckpointStorage(),
+ JOB_ID,
+ new ChannelStateSerializerImpl());
+ List<ChannelStateWriteResult> results = new ArrayList<>(numberOfSubtask);
+ for (int i = 0; i < numberOfSubtask; i++) {
+ processor.dispatch(ChannelStateWriteRequest.registerSubtask(JOB_VERTEX_ID, i));
+ }
+ long checkpointId = 1L;
+ // The aborted subtask differs between runs; the expectations below hold for any choice.
+ int abortedSubtaskIndex = new Random().nextInt(numberOfSubtask);
+ processor.dispatch(
+ ChannelStateWriteRequest.abort(
+ JOB_VERTEX_ID, abortedSubtaskIndex, checkpointId, new TestException()));
+ // Start the already aborted checkpoint for every subtask and keep the results, indexed
+ // by subtask index, for the assertions below.
+ for (int i = 0; i < numberOfSubtask; i++) {
+ ChannelStateWriteResult result = new ChannelStateWriteResult();
+ results.add(result);
+ processor.dispatch(
+ ChannelStateWriteRequest.start(
+ JOB_VERTEX_ID,
+ i,
+ checkpointId,
+ result,
+ CheckpointStorageLocationReference.getDefault()));
+ }
+ assertThat(results).allMatch(ChannelStateWriteResult::isDone);
+ for (int i = 0; i < numberOfSubtask; i++) {
+ ChannelStateWriteResult result = results.get(i);
+ if (i == abortedSubtaskIndex) {
+ assertHasSpecialCause(result, TestException.class);
+ } else {
+ assertCheckpointFailureReason(result, CHANNEL_STATE_SHARED_STREAM_EXCEPTION);
+ }
+ }
}
private NetworkBuffer buffer() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
index 4b59e1235e6..36de469bbfa 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherTest.java
@@ -17,12 +17,15 @@
package org.apache.flink.runtime.checkpoint.channel;
+import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
@@ -43,7 +46,6 @@ import static java.util.Optional.of;
import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeInput;
import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.completeOutput;
import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequest.write;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.assertj.core.api.Assertions.fail;
/** {@link ChannelStateWriteRequestDispatcherImpl} tests. */
@@ -90,6 +92,10 @@ public class ChannelStateWriteRequestDispatcherTest {
new Object[] {of(IllegalStateException.class), asList(start(), start())});
}
+ private static final JobID JOB_ID = new JobID();
+ private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+ private static final int SUBTASK_INDEX = 0;
+
@Parameter public Optional<Class<Exception>> expectedException;
@Parameter(value = 1)
@@ -98,15 +104,17 @@ public class ChannelStateWriteRequestDispatcherTest {
private static final long CHECKPOINT_ID = 42L;
private static CheckpointInProgressRequest completeOut() {
- return completeOutput(CHECKPOINT_ID);
+ return completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX, CHECKPOINT_ID);
}
private static CheckpointInProgressRequest completeIn() {
- return completeInput(CHECKPOINT_ID);
+ return completeInput(JOB_VERTEX_ID, SUBTASK_INDEX, CHECKPOINT_ID);
}
private static ChannelStateWriteRequest writeIn() {
return write(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
CHECKPOINT_ID,
new InputChannelInfo(1, 1),
CloseableIterator.ofElement(
@@ -118,6 +126,8 @@ public class ChannelStateWriteRequestDispatcherTest {
private static ChannelStateWriteRequest writeOut() {
return write(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
CHECKPOINT_ID,
new ResultSubpartitionInfo(1, 1),
new NetworkBuffer(
@@ -128,7 +138,12 @@ public class ChannelStateWriteRequestDispatcherTest {
private static ChannelStateWriteRequest writeOutFuture() {
CompletableFuture<List<Buffer>> outFuture = new CompletableFuture<>();
ChannelStateWriteRequest writeRequest =
- write(CHECKPOINT_ID, new ResultSubpartitionInfo(1, 1), outFuture);
+ write(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
+ CHECKPOINT_ID,
+ new ResultSubpartitionInfo(1, 1),
+ outFuture);
outFuture.complete(
singletonList(
new NetworkBuffer(
@@ -137,8 +152,14 @@ public class ChannelStateWriteRequestDispatcherTest {
return writeRequest;
}
+ /** Creates the registration request for the subtask {@code (JOB_VERTEX_ID, SUBTASK_INDEX)}. */
+ private static SubtaskRegisterRequest register() {
+ return new SubtaskRegisterRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
+ }
+
private static CheckpointStartRequest start() {
return new CheckpointStartRequest(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
CHECKPOINT_ID,
new ChannelStateWriteResult(),
new CheckpointStorageLocationReference(new byte[] {1}));
@@ -148,11 +169,11 @@ public class ChannelStateWriteRequestDispatcherTest {
void doRun() {
ChannelStateWriteRequestDispatcher processor =
new ChannelStateWriteRequestDispatcherImpl(
- "dummy task",
- 0,
- getStreamFactoryFactory(),
+ new JobManagerCheckpointStorage(),
+ JOB_ID,
new ChannelStateSerializerImpl());
try {
+ processor.dispatch(register());
for (ChannelStateWriteRequest request : requests) {
processor.dispatch(request);
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactoryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactoryTest.java
new file mode 100644
index 00000000000..4f67527be2d
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorFactoryTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.CheckpointStorage;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Random;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Test for {@link ChannelStateWriteRequestExecutorFactory}. */
+public class ChannelStateWriteRequestExecutorFactoryTest {
+
+ private static final CheckpointStorage CHECKPOINT_STORAGE = new JobManagerCheckpointStorage();
+
+ /** Checks executor sharing for several values of the per-file subtask limit. */
+ @Test
+ void testReuseExecutorForSameJobId() {
+ assertReuseExecutor(1);
+ assertReuseExecutor(2);
+ assertReuseExecutor(3);
+ assertReuseExecutor(5);
+ assertReuseExecutor(10);
+ }
+
+ /**
+ * Requests executors for 100 subtasks from a single factory and asserts that a new executor
+ * is handed out on every {@code maxSubtasksPerChannelStateFile}-th request, while the
+ * requests in between receive the executor returned before.
+ *
+ * @param maxSubtasksPerChannelStateFile the limit passed to every executor request
+ */
+ private void assertReuseExecutor(int maxSubtasksPerChannelStateFile) {
+ JobID jobId = new JobID();
+ Random random = new Random();
+ ChannelStateWriteRequestExecutorFactory executorFactory =
+ new ChannelStateWriteRequestExecutorFactory(jobId);
+ int numberOfTasks = 100;
+
+ ChannelStateWriteRequestExecutor currentExecutor = null;
+ for (int i = 0; i < numberOfTasks; i++) {
+ // Every request uses a fresh job vertex and a random subtask index; only the
+ // factory, and with it the job, is shared between the requests.
+ ChannelStateWriteRequestExecutor newExecutor =
+ executorFactory.getOrCreateExecutor(
+ new JobVertexID(),
+ random.nextInt(numberOfTasks),
+ CHECKPOINT_STORAGE,
+ maxSubtasksPerChannelStateFile);
+ if (i % maxSubtasksPerChannelStateFile == 0) {
+ // The previous executor has served its full share of subtasks.
+ assertThat(newExecutor)
+ .as("Factory should create the new executor.")
+ .isNotSameAs(currentExecutor);
+ currentExecutor = newExecutor;
+ } else {
+ assertThat(newExecutor)
+ .as("Factory should reuse the old executor.")
+ .isSameAs(currentExecutor);
+ }
+ }
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
index b7aa7afb682..5c248c73895 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestExecutorImplTest.java
@@ -17,21 +17,30 @@
package org.apache.flink.runtime.checkpoint.channel;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.function.BiConsumerWithException;
import org.junit.jupiter.api.Test;
import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
import java.io.IOException;
+import java.util.ArrayDeque;
import java.util.Collections;
+import java.util.Deque;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.LinkedBlockingDeque;
+import java.util.concurrent.atomic.AtomicInteger;
import static org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestDispatcher.NO_OP;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.apache.flink.util.ExceptionUtils.findThrowable;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -40,7 +49,9 @@ import static org.assertj.core.api.Assertions.fail;
/** {@link ChannelStateWriteRequestExecutorImpl} test. */
class ChannelStateWriteRequestExecutorImplTest {
- private static final String TASK_NAME = "test task";
+ private static final JobID JOB_ID = new JobID();
+ private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+ private static final int SUBTASK_INDEX = 0;
@Test
void testCloseAfterSubmit() {
@@ -74,9 +85,10 @@ class ChannelStateWriteRequestExecutorImplTest {
throws Exception {
WorkerClosingDeque closingDeque = new WorkerClosingDeque();
ChannelStateWriteRequestExecutorImpl worker =
- new ChannelStateWriteRequestExecutorImpl(TASK_NAME, NO_OP, closingDeque);
+ new ChannelStateWriteRequestExecutorImpl(NO_OP, closingDeque, 5, e -> {});
closingDeque.setWorker(worker);
- TestWriteRequest request = new TestWriteRequest();
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ TestWriteRequest request = new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
requestFun.accept(worker, request);
assertThat(closingDeque).isEmpty();
assertThat(request.isCancelled()).isFalse();
@@ -87,11 +99,13 @@ class ChannelStateWriteRequestExecutorImplTest {
ChannelStateWriteRequestExecutor, ChannelStateWriteRequest, Exception>
submitAction)
throws Exception {
- TestWriteRequest request = new TestWriteRequest();
- LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
+ TestWriteRequest request = new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
+ Deque<ChannelStateWriteRequest> deque = new ArrayDeque<>();
try {
- submitAction.accept(
- new ChannelStateWriteRequestExecutorImpl(TASK_NAME, NO_OP, deque), request);
+ ChannelStateWriteRequestExecutorImpl executor =
+ new ChannelStateWriteRequestExecutorImpl(NO_OP, deque, 5, e -> {});
+ executor.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ submitAction.accept(executor, request);
} catch (IllegalStateException e) {
// expected: executor not started;
return;
@@ -105,14 +119,14 @@ class ChannelStateWriteRequestExecutorImplTest {
@Test
@SuppressWarnings("CallToThreadRun")
void testCleanup() throws IOException {
- TestWriteRequest request = new TestWriteRequest();
- LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
+ TestWriteRequest request = new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
+ Deque<ChannelStateWriteRequest> deque = new ArrayDeque<>();
deque.add(request);
TestRequestDispatcher requestProcessor = new TestRequestDispatcher();
ChannelStateWriteRequestExecutorImpl worker =
- new ChannelStateWriteRequestExecutorImpl(TASK_NAME, requestProcessor, deque);
-
- worker.close();
+ new ChannelStateWriteRequestExecutorImpl(requestProcessor, deque, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
worker.run();
assertThat(requestProcessor.isStopped()).isTrue();
@@ -123,16 +137,20 @@ class ChannelStateWriteRequestExecutorImplTest {
@Test
void testIgnoresInterruptsWhileRunning() throws Exception {
TestRequestDispatcher requestProcessor = new TestRequestDispatcher();
- LinkedBlockingDeque<ChannelStateWriteRequest> deque = new LinkedBlockingDeque<>();
- try (ChannelStateWriteRequestExecutorImpl worker =
- new ChannelStateWriteRequestExecutorImpl(TASK_NAME, requestProcessor, deque)) {
+ Deque<ChannelStateWriteRequest> deque = new ArrayDeque<>();
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(requestProcessor, deque, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ try {
worker.start();
worker.getThread().interrupt();
- worker.submit(new TestWriteRequest());
+ worker.submit(new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX));
worker.getThread().interrupt();
while (!deque.isEmpty()) {
Thread.sleep(100);
}
+ } finally {
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
}
}
@@ -141,28 +159,123 @@ class ChannelStateWriteRequestExecutorImplTest {
long checkpointId = 1L;
ChannelStateWriteRequestDispatcher processor =
new ChannelStateWriteRequestDispatcherImpl(
- "dummy task",
- 0,
- getStreamFactoryFactory(),
+ new JobManagerCheckpointStorage(),
+ JOB_ID,
new ChannelStateSerializerImpl());
- try (ChannelStateWriteRequestExecutorImpl worker =
- new ChannelStateWriteRequestExecutorImpl(TASK_NAME, processor)) {
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(processor, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ try {
worker.start();
worker.submit(
new CheckpointStartRequest(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
checkpointId,
new ChannelStateWriter.ChannelStateWriteResult(),
CheckpointStorageLocationReference.getDefault()));
worker.submit(
ChannelStateWriteRequest.write(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
checkpointId,
new ResultSubpartitionInfo(0, 0),
new CompletableFuture<>()));
worker.submit(
ChannelStateWriteRequest.write(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
checkpointId,
new ResultSubpartitionInfo(0, 0),
new CompletableFuture<>()));
+ } finally {
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ }
+ }
+
+ @Test
+ void testSkipUnreadyDataFuture() throws Exception {
+ int subtaskIndex0 = 0;
+ int subtaskIndex1 = 1;
+
+ Queue<ChannelStateWriteRequest> firstBatchRequests = new LinkedList<>();
+ Queue<ChannelStateWriteRequest> secondBatchRequests = new LinkedList<>();
+ CompletableFuture<List<Buffer>> dataFuture = new CompletableFuture<>();
+ int firstBatchSubtask1Count = 3;
+ int subtask0Count = 4;
+ int subtask1Count = 4;
+
+ {
+ // Generate the first batch requests.
+ firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+ firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0));
+ // Add a data future request; all subsequent requests of subtaskIndex0 should be blocked
+ // until this future is completed.
+ firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0, dataFuture));
+ firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0));
+ firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+ firstBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+
+ // Generate the second batch requests.
+ secondBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex0));
+ secondBatchRequests.add(new TestWriteRequest(JOB_VERTEX_ID, subtaskIndex1));
+ }
+
+ CompletableFuture<Void> firstBatchFuture = new CompletableFuture<>();
+ CompletableFuture<Void> allReceivedFuture = new CompletableFuture<>();
+
+ // Start at -1 so that the subtask register request is not counted.
+ AtomicInteger subtask0ReceivedCounter = new AtomicInteger(-1);
+ AtomicInteger subtask1ReceivedCounter = new AtomicInteger(-1);
+
+ TestRequestDispatcher throwingRequestProcessor =
+ new TestRequestDispatcher() {
+ @Override
+ public void dispatch(ChannelStateWriteRequest request) {
+ if (request.getSubtaskIndex() == subtaskIndex0) {
+ subtask0ReceivedCounter.incrementAndGet();
+ } else if (request.getSubtaskIndex() == subtaskIndex1) {
+ if (subtask1ReceivedCounter.incrementAndGet()
+ == firstBatchSubtask1Count) {
+ firstBatchFuture.complete(null);
+ }
+ } else {
+ throw new IllegalStateException(
+ String.format(
+ "Unknown subtask index %s.",
+ request.getSubtaskIndex()));
+ }
+ if (subtask0ReceivedCounter.get() == subtask0Count
+ && subtask1ReceivedCounter.get() == subtask1Count) {
+ allReceivedFuture.complete(null);
+ }
+ }
+ };
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(throwingRequestProcessor, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, subtaskIndex0);
+ worker.registerSubtask(JOB_VERTEX_ID, subtaskIndex1);
+ try {
+ worker.start();
+ // start the first batch
+ for (ChannelStateWriteRequest request : firstBatchRequests) {
+ worker.submit(request);
+ }
+ firstBatchFuture.get();
+ assertThat(subtask0ReceivedCounter.get()).isOne();
+ assertThat(subtask1ReceivedCounter.get()).isEqualTo(firstBatchSubtask1Count);
+
+ // start the second batch
+ for (ChannelStateWriteRequest request : secondBatchRequests) {
+ worker.submit(request);
+ }
+ dataFuture.complete(Collections.emptyList());
+ allReceivedFuture.get();
+ assertThat(subtask0ReceivedCounter.get()).isEqualTo(subtask0Count);
+ assertThat(subtask1ReceivedCounter.get()).isEqualTo(subtask1Count);
+ } finally {
+ worker.releaseSubtask(JOB_VERTEX_ID, subtaskIndex0);
+ worker.releaseSubtask(JOB_VERTEX_ID, subtaskIndex1);
}
}
@@ -176,14 +289,17 @@ class ChannelStateWriteRequestExecutorImplTest {
throw testException;
}
};
- LinkedBlockingDeque<ChannelStateWriteRequest> deque =
- new LinkedBlockingDeque<>(Collections.singletonList(new TestWriteRequest()));
+ Deque<ChannelStateWriteRequest> deque =
+ new ArrayDeque<>(
+ Collections.singletonList(
+ new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX)));
ChannelStateWriteRequestExecutorImpl worker =
new ChannelStateWriteRequestExecutorImpl(
- TASK_NAME, throwingRequestProcessor, deque);
+ throwingRequestProcessor, deque, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
worker.run();
try {
- worker.close();
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
} catch (IOException e) {
if (findThrowable(e, TestException.class)
.filter(found -> found == testException)
@@ -196,12 +312,136 @@ class ChannelStateWriteRequestExecutorImplTest {
fail("exception not thrown");
}
- private static class TestWriteRequest implements ChannelStateWriteRequest {
+ /**
+ * Submitting a request, normal or priority, for a subtask that was never registered with the
+ * executor must be rejected with an {@link IllegalArgumentException}.
+ */
+ @Test
+ void testSubmitRequestOfUnregisteredSubtask() throws Exception {
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(NO_OP, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ worker.start();
+ try {
+ // A request of the registered subtask is accepted.
+ worker.submit(new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX));
+
+ // A fresh JobVertexID has not been registered, so both submit paths must reject it.
+ assertThatThrownBy(
+ () -> worker.submit(new TestWriteRequest(new JobVertexID(), SUBTASK_INDEX)))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("is not yet registered.");
+
+ assertThatThrownBy(
+ () ->
+ worker.submitPriority(
+ new TestWriteRequest(new JobVertexID(), SUBTASK_INDEX)))
+ .isInstanceOf(IllegalArgumentException.class)
+ .hasMessageContaining("is not yet registered.");
+ } finally {
+ // Same cleanup as the other tests that call start(): release the only registered
+ // subtask so the started executor is not left behind (presumably this stops it).
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ }
+ }
+
+ /**
+ * A priority request whose ready future is not yet completed must be rejected with an {@link
+ * IllegalStateException}, while a priority request without such a future is accepted.
+ */
+ @Test
+ void testSubmitPriorityUnreadyRequest() throws Exception {
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(NO_OP, 5, e -> {});
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ worker.start();
+ try {
+ worker.submitPriority(new TestWriteRequest(JOB_VERTEX_ID, SUBTASK_INDEX));
+
+ // The CompletableFuture passed here is never completed, i.e. the request is unready.
+ assertThatThrownBy(
+ () ->
+ worker.submitPriority(
+ new TestWriteRequest(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
+ new CompletableFuture<>())))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("The priority request must be ready.");
+ } finally {
+ // Same cleanup as the other tests that call start(): release the only registered
+ // subtask so the started executor is not left behind (presumably this stops it).
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ }
+ }
+
+ /**
+ * Registering {@code maxSubtasksPerChannelStateFile} subtasks must end the registering phase;
+ * any further {@code registerSubtask} call must then fail with an {@link
+ * IllegalStateException}.
+ */
+ @Test
+ void testRegisterSubtaskAfterRegisterCompleted() throws Exception {
+ int maxSubtasksPerChannelStateFile = 5;
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(
+ NO_OP, maxSubtasksPerChannelStateFile, e -> {});
+ for (int i = 0; i < maxSubtasksPerChannelStateFile; i++) {
+ // The executor must still be registering before each of the first five registrations.
+ assertThat(worker.isRegistering()).isTrue();
+ worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX);
+ }
+ // The fifth registration filled the executor, so a sixth one is rejected.
+ assertThat(worker.isRegistering()).isFalse();
+ assertThatThrownBy(() -> worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("This executor has been registered.");
+ }
+
+ /**
+ * Submitting a checkpoint start request while the executor is still registering must end the
+ * registering phase: the callback passed to the constructor receives the executor and further
+ * {@code registerSubtask} calls fail with an {@link IllegalStateException}.
+ */
+ @Test
+ void testSubmitStartRequestBeforeRegisterCompleted() throws Exception {
+ CompletableFuture<Void> dispatcherFuture = new CompletableFuture<>();
+ TestRequestDispatcher dispatcher =
+ new TestRequestDispatcher() {
+ @Override
+ public void dispatch(ChannelStateWriteRequest request) {
+ // Signal the test thread once the start request reaches the dispatcher.
+ if (request instanceof CheckpointStartRequest) {
+ dispatcherFuture.complete(null);
+ }
+ }
+ };
+ int maxSubtasksPerChannelStateFile = 5;
+ CompletableFuture<ChannelStateWriteRequestExecutor> workerFuture =
+ new CompletableFuture<>();
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(
+ dispatcher, maxSubtasksPerChannelStateFile, workerFuture::complete);
+ worker.start();
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ try {
+ // Only one of the five allowed subtasks is registered so far.
+ assertThat(worker.isRegistering()).isTrue();
+
+ worker.submit(
+ ChannelStateWriteRequest.start(
+ JOB_VERTEX_ID,
+ SUBTASK_INDEX,
+ 1,
+ new ChannelStateWriter.ChannelStateWriteResult(),
+ CheckpointStorageLocationReference.getDefault()));
+ // Wait until the executor thread has dispatched the start request.
+ dispatcherFuture.get();
+ assertThat(worker.isRegistering()).isFalse();
+ assertThat(workerFuture).isCompletedWithValue(worker);
+ assertThatThrownBy(() -> worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("This executor has been registered.");
+ } finally {
+ // Same cleanup as the other tests that call start(): release the only registered
+ // subtask so the started executor is not left behind (presumably this stops it).
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ }
+ }
+
+ /**
+ * Releasing a subtask while the executor is still registering must end the registering phase:
+ * the callback passed to the constructor receives the executor and further {@code
+ * registerSubtask} calls fail with an {@link IllegalStateException}.
+ */
+ @Test
+ void testReleaseSubtaskBeforeRegisterCompleted() throws Exception {
+ int maxSubtasksPerChannelStateFile = 5;
+ CompletableFuture<ChannelStateWriteRequestExecutor> workerFuture =
+ new CompletableFuture<>();
+ ChannelStateWriteRequestExecutorImpl worker =
+ new ChannelStateWriteRequestExecutorImpl(
+ new TestRequestDispatcher(),
+ maxSubtasksPerChannelStateFile,
+ workerFuture::complete);
+ worker.start();
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ // Only one of the five allowed subtasks is registered so far.
+ assertThat(worker.isRegistering()).isTrue();
+
+ // Releasing the single registered subtask is the action under test.
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
+ assertThat(worker.isRegistering()).isFalse();
+ assertThat(workerFuture).isCompletedWithValue(worker);
+ assertThatThrownBy(() -> worker.registerSubtask(new JobVertexID(), SUBTASK_INDEX))
+ .isInstanceOf(IllegalStateException.class)
+ .hasMessage("This executor has been registered.");
+ }
+
+ private static class TestWriteRequest extends ChannelStateWriteRequest {
private boolean cancelled = false;
- @Override
- public long getCheckpointId() {
- return 0;
+ @Nullable private final CompletableFuture<?> readyFuture;
+
+ public TestWriteRequest(JobVertexID jobVertexID, int subtaskIndex) {
+ this(jobVertexID, subtaskIndex, null);
+ }
+
+ public TestWriteRequest(
+ JobVertexID jobVertexID,
+ int subtaskIndex,
+ @Nullable CompletableFuture<?> readyFuture) {
+ super(jobVertexID, subtaskIndex, 0, "Test");
+ this.readyFuture = readyFuture;
}
@Override
@@ -212,27 +452,35 @@ class ChannelStateWriteRequestExecutorImplTest {
public boolean isCancelled() {
return cancelled;
}
+
+ @Override
+ public CompletableFuture<?> getReadyFuture() {
+ if (readyFuture != null) {
+ return readyFuture;
+ }
+ return super.getReadyFuture();
+ }
}
- private static class WorkerClosingDeque extends LinkedBlockingDeque<ChannelStateWriteRequest> {
+ private static class WorkerClosingDeque extends ArrayDeque<ChannelStateWriteRequest> {
private ChannelStateWriteRequestExecutor worker;
@Override
- public void put(@Nonnull ChannelStateWriteRequest request) throws InterruptedException {
- super.putFirst(request);
+ public boolean add(@Nonnull ChannelStateWriteRequest request) {
+ boolean add = super.add(request);
try {
- worker.close();
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
} catch (IOException e) {
ExceptionUtils.rethrow(e);
}
+ return add;
}
@Override
- public void putFirst(@Nonnull ChannelStateWriteRequest request)
- throws InterruptedException {
- super.putFirst(request);
+ public void addFirst(@Nonnull ChannelStateWriteRequest request) {
+ super.addFirst(request);
try {
- worker.close();
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
} catch (IOException e) {
ExceptionUtils.rethrow(e);
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteResultUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteResultUtil.java
new file mode 100644
index 00000000000..ee6a51a684f
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteResultUtil.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint.channel;
+
+import org.apache.flink.runtime.checkpoint.CheckpointException;
+import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
+
+import java.util.Collection;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
+
+/**
+ * Static AssertJ-based assertions on {@link ChannelStateWriter.ChannelStateWriteResult} for use in
+ * tests. Every assertion inspects both the input channel state handles and the result
+ * subpartition state handles of a result.
+ */
+public class ChannelStateWriteResultUtil {
+
+ /**
+ * Asserts that calling {@code get()} on both handle futures of {@code result} throws and that
+ * the cause of each thrown exception is an instance of {@code type}.
+ *
+ * @param result the write result to inspect
+ * @param type the expected class of the failure cause
+ */
+ public static void assertHasSpecialCause(
+ ChannelStateWriter.ChannelStateWriteResult result, Class<? extends Throwable> type) {
+ assertThatThrownBy(() -> result.getInputChannelStateHandles().get())
+ .hasCauseInstanceOf(type);
+ assertThatThrownBy(() -> result.getResultSubpartitionStateHandles().get())
+ .hasCauseInstanceOf(type);
+ }
+
+ /**
+ * Asserts that calling {@code get()} on both handle futures of {@code result} throws with a
+ * {@link CheckpointException} cause whose failure reason equals {@code
+ * checkpointFailureReason}.
+ *
+ * @param result the write result to inspect
+ * @param checkpointFailureReason the failure reason the cause must carry
+ */
+ public static void assertCheckpointFailureReason(
+ ChannelStateWriter.ChannelStateWriteResult result,
+ CheckpointFailureReason checkpointFailureReason) {
+ assertThatThrownBy(() -> result.getInputChannelStateHandles().get())
+ .cause()
+ .asInstanceOf(type(CheckpointException.class))
+ .satisfies(
+ checkpointException ->
+ assertThat(checkpointException.getCheckpointFailureReason())
+ .isEqualTo(checkpointFailureReason));
+
+ assertThatThrownBy(() -> result.getResultSubpartitionStateHandles().get())
+ .cause()
+ .asInstanceOf(type(CheckpointException.class))
+ .satisfies(
+ checkpointException ->
+ assertThat(checkpointException.getCheckpointFailureReason())
+ .isEqualTo(checkpointFailureReason));
+ }
+
+ /**
+ * Asserts that every result reports {@code isDone()} and that neither of its handle futures
+ * completed exceptionally.
+ *
+ * @param results the write results to inspect
+ */
+ public static void assertAllSubtaskDoneNormally(
+ Collection<ChannelStateWriter.ChannelStateWriteResult> results) {
+ assertThat(results)
+ .allMatch(ChannelStateWriter.ChannelStateWriteResult::isDone)
+ .allMatch(
+ result -> !result.getInputChannelStateHandles().isCompletedExceptionally())
+ .allMatch(
+ result ->
+ !result.getResultSubpartitionStateHandles()
+ .isCompletedExceptionally());
+ }
+
+ /**
+ * Asserts that none of the given results reports {@code isDone()}.
+ *
+ * @param results the write results to inspect
+ */
+ public static void assertAllSubtaskNotDone(
+ Collection<ChannelStateWriter.ChannelStateWriteResult> results) {
+ assertThat(results).allMatch(result -> !result.isDone());
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
index fcf9ad65f2a..0b2cbe3cc6c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
@@ -17,12 +17,15 @@
package org.apache.flink.runtime.checkpoint.channel;
+import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.util.function.BiConsumerWithException;
import org.junit.jupiter.api.Test;
@@ -33,7 +36,6 @@ import java.util.Deque;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.apache.flink.util.CloseableIterator.ofElements;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -42,6 +44,9 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
class ChannelStateWriterImplTest {
private static final long CHECKPOINT_ID = 42L;
private static final String TASK_NAME = "test";
+ private static final JobID JOB_ID = new JobID();
+ private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+ private static final int SUBTASK_INDEX = 0;
@Test
void testAddEventBuffer() throws Exception {
@@ -141,17 +146,20 @@ class ChannelStateWriterImplTest {
}
@Test
- void testBuffersRecycledOnError() throws IOException {
+ void testBuffersRecycledOnError() {
NetworkBuffer buffer = getBuffer();
- try (ChannelStateWriterImpl writer =
+ ChannelStateWriterImpl writer =
new ChannelStateWriterImpl(
- TASK_NAME, new ConcurrentHashMap<>(), failingWorker(), 5)) {
- writer.open();
- assertThatThrownBy(() -> callAddInputData(writer, buffer))
- .isInstanceOf(RuntimeException.class)
- .hasCauseInstanceOf(TestException.class);
- assertThat(buffer.isRecycled()).isTrue();
- }
+ JOB_VERTEX_ID,
+ TASK_NAME,
+ SUBTASK_INDEX,
+ new ConcurrentHashMap<>(),
+ failingWorker(),
+ 5);
+ assertThatThrownBy(() -> callAddInputData(writer, buffer))
+ .isInstanceOf(RuntimeException.class)
+ .hasCauseInstanceOf(TestException.class);
+ assertThat(buffer.isRecycled()).isTrue();
}
@Test
@@ -210,10 +218,17 @@ class ChannelStateWriterImplTest {
@Test
void testRethrowOnNextCall() {
- SyncChannelStateWriteRequestExecutor worker = new SyncChannelStateWriteRequestExecutor();
+ SyncChannelStateWriteRequestExecutor worker =
+ new SyncChannelStateWriteRequestExecutor(JOB_ID);
ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(TASK_NAME, new ConcurrentHashMap<>(), worker, 5);
- writer.open();
+ new ChannelStateWriterImpl(
+ JOB_VERTEX_ID,
+ TASK_NAME,
+ SUBTASK_INDEX,
+ new ConcurrentHashMap<>(),
+ worker,
+ 5);
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
worker.setThrown(new TestException());
assertThatThrownBy(() -> callStart(writer)).hasCauseInstanceOf(TestException.class);
}
@@ -223,8 +238,13 @@ class ChannelStateWriterImplTest {
int maxCheckpoints = 3;
try (ChannelStateWriterImpl writer =
new ChannelStateWriterImpl(
- TASK_NAME, 0, getStreamFactoryFactory(), maxCheckpoints)) {
- writer.open();
+ JOB_VERTEX_ID,
+ TASK_NAME,
+ SUBTASK_INDEX,
+ new JobManagerCheckpointStorage(),
+ maxCheckpoints,
+ new ChannelStateWriteRequestExecutorFactory(JOB_ID),
+ 5)) {
for (int i = 0; i < maxCheckpoints; i++) {
writer.start(i, CheckpointOptions.forCheckpointWithDefaultLocation());
}
@@ -237,15 +257,6 @@ class ChannelStateWriterImplTest {
}
}
- @Test
- void testStartNotOpened() throws IOException {
- try (ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(TASK_NAME, 0, getStreamFactoryFactory())) {
- assertThatThrownBy(() -> callStart(writer))
- .hasCauseInstanceOf(IllegalStateException.class);
- }
- }
-
@Test
void testNoStartAfterClose() throws IOException {
ChannelStateWriterImpl writer = openWriter();
@@ -274,8 +285,6 @@ class ChannelStateWriterImplTest {
private ChannelStateWriteRequestExecutor failingWorker() {
return new ChannelStateWriteRequestExecutor() {
- @Override
- public void close() {}
@Override
public void submit(ChannelStateWriteRequest e) {
@@ -289,6 +298,12 @@ class ChannelStateWriterImplTest {
@Override
public void start() throws IllegalStateException {}
+
+ @Override
+ public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
+
+ @Override
+ public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {}
};
}
@@ -306,21 +321,31 @@ class ChannelStateWriterImplTest {
ChannelStateWriter, SyncChannelStateWriteRequestExecutor, Exception>
testFn)
throws Exception {
- try (SyncChannelStateWriteRequestExecutor worker =
- new SyncChannelStateWriteRequestExecutor();
- ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(
- TASK_NAME, new ConcurrentHashMap<>(), worker, 5)) {
- writer.open();
+ SyncChannelStateWriteRequestExecutor worker =
+ new SyncChannelStateWriteRequestExecutor(JOB_ID);
+ try (ChannelStateWriterImpl writer =
+ new ChannelStateWriterImpl(
+ JOB_VERTEX_ID,
+ TASK_NAME,
+ SUBTASK_INDEX,
+ new ConcurrentHashMap<>(),
+ worker,
+ 5)) {
+ worker.registerSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
testFn.accept(writer, worker);
+ } finally {
+ worker.releaseSubtask(JOB_VERTEX_ID, SUBTASK_INDEX);
}
}
private ChannelStateWriterImpl openWriter() {
- ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(TASK_NAME, 0, getStreamFactoryFactory());
- writer.open();
- return writer;
+ return new ChannelStateWriterImpl(
+ JOB_VERTEX_ID,
+ TASK_NAME,
+ SUBTASK_INDEX,
+ new JobManagerCheckpointStorage(),
+ new ChannelStateWriteRequestExecutorFactory(JOB_ID),
+ 5);
}
private void callStart(ChannelStateWriter writer) {
@@ -352,14 +377,11 @@ class SyncChannelStateWriteRequestExecutor implements ChannelStateWriteRequestEx
private final Deque<ChannelStateWriteRequest> deque;
private Exception thrown;
- SyncChannelStateWriteRequestExecutor() {
+ SyncChannelStateWriteRequestExecutor(JobID jobID) {
deque = new ArrayDeque<>();
requestProcessor =
new ChannelStateWriteRequestDispatcherImpl(
- "dummy task",
- 0,
- getStreamFactoryFactory(),
- new ChannelStateSerializerImpl());
+ new JobManagerCheckpointStorage(), jobID, new ChannelStateSerializerImpl());
}
@Override
@@ -382,7 +404,14 @@ class SyncChannelStateWriteRequestExecutor implements ChannelStateWriteRequestEx
public void start() throws IllegalStateException {}
@Override
- public void close() {}
+ public void registerSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+ deque.add(ChannelStateWriteRequest.registerSubtask(jobVertexID, subtaskIndex));
+ }
+
+ @Override
+ public void releaseSubtask(JobVertexID jobVertexID, int subtaskIndex) {
+ deque.add(ChannelStateWriteRequest.releaseSubtask(jobVertexID, subtaskIndex));
+ }
void processAllRequests() throws Exception {
while (!deque.isEmpty()) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
index 12a87c14183..0d4fe2b14e5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
@@ -17,6 +17,8 @@
package org.apache.flink.runtime.checkpoint.channel;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+
import org.junit.jupiter.api.Test;
import java.util.concurrent.CyclicBarrier;
@@ -62,6 +64,8 @@ class CheckpointInProgressRequestTest {
AtomicInteger cancelCounter, CyclicBarrier cb) {
return new CheckpointInProgressRequest(
"test",
+ new JobVertexID(),
+ 0,
1L,
unused -> {},
unused -> {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 260e83331bb..ef35e7a9218 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -73,6 +74,8 @@ public class DummyEnvironment implements Environment {
private final AccumulatorRegistry accumulatorRegistry;
private UserCodeClassLoader userClassLoader;
private final Configuration taskConfiguration = new Configuration();
+ private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory =
+ new ChannelStateWriteRequestExecutorFactory(jobId);
public DummyEnvironment() {
this("Test Job", 1, 0, 1);
@@ -268,4 +271,9 @@ public class DummyEnvironment implements Environment {
public TaskOperatorEventGateway getOperatorCoordinatorEventGateway() {
return new NoOpTaskOperatorEventGateway();
}
+
+ @Override
+ public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+ return channelStateExecutorFactory;
+ }
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index a5a53780b6c..f87b9d417f4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -135,6 +136,8 @@ public class MockEnvironment implements Environment, AutoCloseable {
private CheckpointStorageAccess checkpointStorageAccess;
+ private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
public static MockEnvironmentBuilder builder() {
return new MockEnvironmentBuilder();
}
@@ -157,7 +160,8 @@ public class MockEnvironment implements Environment, AutoCloseable {
TaskMetricGroup taskMetricGroup,
TaskManagerRuntimeInfo taskManagerRuntimeInfo,
MemoryManager memManager,
- ExternalResourceInfoProvider externalResourceInfoProvider) {
+ ExternalResourceInfoProvider externalResourceInfoProvider,
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
this.jobID = jobID;
this.jobVertexID = jobVertexID;
@@ -195,6 +199,7 @@ public class MockEnvironment implements Environment, AutoCloseable {
this.mainMailboxExecutor = new SyncMailboxExecutor();
this.asyncOperationsThreadPool = Executors.newDirectExecutorService();
+ this.channelStateExecutorFactory = channelStateExecutorFactory;
}
public IteratorWrappingTestSingleInputGate<Record> addInput(
@@ -441,6 +446,11 @@ public class MockEnvironment implements Environment, AutoCloseable {
return checkNotNull(checkpointStorageAccess);
}
+ /** Returns the channel state executor factory that was handed to the constructor. */
+ @Override
+ public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+ return channelStateExecutorFactory;
+ }
+
public void setExpectedExternalFailureCause(Class<? extends Throwable> expectedThrowableClass) {
this.expectedExternalFailureCause = Optional.of(expectedThrowableClass);
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java
index d3e723d05bf..ec61b9a6fa5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironmentBuilder.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.operators.testutils;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.io.disk.iomanager.IOManagerAsync;
@@ -61,6 +62,8 @@ public class MockEnvironmentBuilder {
buildMemoryManager(1024 * MemoryManager.DEFAULT_PAGE_SIZE);
private ExternalResourceInfoProvider externalResourceInfoProvider =
ExternalResourceInfoProvider.NO_EXTERNAL_RESOURCES;
+ private ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory =
+ new ChannelStateWriteRequestExecutorFactory(jobID);
private MemoryManager buildMemoryManager(long memorySize) {
return MemoryManagerBuilder.newBuilder().setMemorySize(memorySize).build();
@@ -159,6 +162,12 @@ public class MockEnvironmentBuilder {
return this;
}
+ /**
+ * Replaces the channel state executor factory that {@link #build()} passes to the {@link
+ * MockEnvironment}.
+ *
+ * <p>NOTE(review): the default factory is created from the {@code jobID} field when this builder
+ * is instantiated; confirm whether a job id configured afterwards should also replace it.
+ *
+ * @param channelStateExecutorFactory the factory to use for the built environment
+ * @return this builder
+ */
+ public MockEnvironmentBuilder setChannelStateWriteRequestExecutorFactory(
+ ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory) {
+ this.channelStateExecutorFactory = channelStateExecutorFactory;
+ return this;
+ }
+
public MockEnvironment build() {
if (ioManager == null) {
ioManager = new IOManagerAsync();
@@ -181,6 +190,7 @@ public class MockEnvironmentBuilder {
taskMetricGroup,
taskManagerRuntimeInfo,
memoryManager,
- externalResourceInfoProvider);
+ externalResourceInfoProvider,
+ channelStateExecutorFactory);
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
index 7f1518f529c..fe205a345ea 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/ChannelPersistenceITCase.java
@@ -17,11 +17,13 @@
package org.apache.flink.runtime.state;
+import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter.ChannelStateWriteResult;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
@@ -44,8 +46,9 @@ import org.apache.flink.runtime.io.network.partition.consumer.InputChannelBuilde
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.state.memory.NonPersistentMetadataCheckpointStorageLocation;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.util.function.SupplierWithException;
import org.junit.Test;
@@ -75,6 +78,9 @@ import static org.junit.Assert.assertNull;
/** ChannelPersistenceITCase. */
public class ChannelPersistenceITCase {
private static final Random RANDOM = new Random(System.currentTimeMillis());
+ private static final JobID JOB_ID = new JobID();
+ private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+ private static final int SUBTASK_INDEX = 0;
@Test
public void testUpstreamBlocksAfterRecoveringState() throws Exception {
@@ -250,8 +256,13 @@ public class ChannelPersistenceITCase {
Map<ResultSubpartitionInfo, Buffer> rsBuffers = wrapWithBuffers(rsMap);
Map<ResultSubpartitionInfo, Buffer> rsFutureBuffers = wrapWithBuffers(rsFutureMap);
try (ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl("test", 0, getStreamFactoryFactory(maxStateSize))) {
- writer.open();
+ new ChannelStateWriterImpl(
+ JOB_VERTEX_ID,
+ "test",
+ SUBTASK_INDEX,
+ new JobManagerCheckpointStorage(maxStateSize),
+ new ChannelStateWriteRequestExecutorFactory(JOB_ID),
+ 5)) {
writer.start(
checkpointId,
new CheckpointOptions(
@@ -283,30 +294,6 @@ public class ChannelPersistenceITCase {
}
}
- public static CheckpointStorageWorkerView getStreamFactoryFactory() {
- return getStreamFactoryFactory(42);
- }
-
- public static CheckpointStorageWorkerView getStreamFactoryFactory(int maxStateSize) {
- return new CheckpointStorageWorkerView() {
- @Override
- public CheckpointStreamFactory resolveCheckpointStorageLocation(
- long checkpointId, CheckpointStorageLocationReference reference) {
- return new NonPersistentMetadataCheckpointStorageLocation(maxStateSize);
- }
-
- @Override
- public CheckpointStateOutputStream createTaskOwnedStateStream() {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public CheckpointStateToolset createTaskOwnedCheckpointStateToolset() {
- throw new UnsupportedOperationException();
- }
- };
- }
-
private TaskStateSnapshot toTaskStateSnapshot(ChannelStateWriteResult t) throws Exception {
return new TaskStateSnapshot(
singletonMap(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManagerTest.java
new file mode 100644
index 00000000000..32b2e28528e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/TaskExecutorChannelStateExecutorFactoryManagerTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/** Tests for {@link TaskExecutorChannelStateExecutorFactoryManager}. */
+public class TaskExecutorChannelStateExecutorFactoryManagerTest {
+
+ // Repeated lookups for one job are expected to return the very same factory instance, while a
+ // lookup for another job is expected to return a different instance.
+ @Test
+ void testReuseFactory() {
+ TaskExecutorChannelStateExecutorFactoryManager manager =
+ new TaskExecutorChannelStateExecutorFactoryManager();
+
+ JobID jobID = new JobID();
+ ChannelStateWriteRequestExecutorFactory factory = manager.getOrCreateExecutorFactory(jobID);
+ assertThat(manager.getOrCreateExecutorFactory(jobID))
+ .as("Same job should share the executor factory.")
+ .isSameAs(factory);
+
+ assertThat(manager.getOrCreateExecutorFactory(new JobID()))
+ .as("Different jobs cannot share executor factory.")
+ .isNotSameAs(factory);
+ manager.shutdown();
+ }
+
+ // A job is expected to have no factory before the first getOrCreateExecutorFactory call, to have
+ // one afterwards, and to have none again once its resources were released.
+ @Test
+ void testReleaseForJob() {
+ TaskExecutorChannelStateExecutorFactoryManager manager =
+ new TaskExecutorChannelStateExecutorFactoryManager();
+
+ JobID jobID = new JobID();
+ assertThat(manager.getFactoryByJobId(jobID)).isNull();
+ manager.getOrCreateExecutorFactory(jobID);
+ assertThat(manager.getFactoryByJobId(jobID)).isNotNull();
+
+ manager.releaseResourcesForJob(jobID);
+ assertThat(manager.getFactoryByJobId(jobID)).isNull();
+ manager.shutdown();
+ }
+
+ // After shutdown, requesting a factory is expected to fail with an IllegalStateException, both
+ // for a job that already had a factory and for a job the manager has never seen.
+ @Test
+ void testShutdown() {
+ TaskExecutorChannelStateExecutorFactoryManager manager =
+ new TaskExecutorChannelStateExecutorFactoryManager();
+
+ JobID jobID = new JobID();
+ manager.getOrCreateExecutorFactory(jobID);
+ manager.shutdown();
+
+ assertThatThrownBy(() -> manager.getOrCreateExecutorFactory(jobID))
+ .isInstanceOf(IllegalStateException.class);
+ assertThatThrownBy(() -> manager.getOrCreateExecutorFactory(new JobID()))
+ .isInstanceOf(IllegalStateException.class);
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java
index cfbf6548dd5..93671d14d4d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskexecutor/TaskManagerServicesBuilder.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.memory.SharedResources;
import org.apache.flink.runtime.query.KvStateRegistry;
import org.apache.flink.runtime.registration.RetryingRegistrationConfiguration;
import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.state.TaskExecutorChannelStateExecutorFactoryManager;
import org.apache.flink.runtime.state.TaskExecutorLocalStateStoresManager;
import org.apache.flink.runtime.state.TaskExecutorStateChangelogStoragesManager;
import org.apache.flink.runtime.taskexecutor.slot.NoOpSlotAllocationSnapshotPersistenceService;
@@ -58,6 +59,7 @@ public class TaskManagerServicesBuilder {
private JobLeaderService jobLeaderService;
private TaskExecutorLocalStateStoresManager taskStateManager;
private TaskExecutorStateChangelogStoragesManager taskChangelogStoragesManager;
+ private TaskExecutorChannelStateExecutorFactoryManager taskChannelStateExecutorFactoryManager;
private TaskEventDispatcher taskEventDispatcher;
private LibraryCacheManager libraryCacheManager;
private SharedResources sharedResources;
@@ -82,6 +84,8 @@ public class TaskManagerServicesBuilder {
RetryingRegistrationConfiguration.defaultConfiguration());
taskStateManager = mock(TaskExecutorLocalStateStoresManager.class);
taskChangelogStoragesManager = mock(TaskExecutorStateChangelogStoragesManager.class);
+ taskChannelStateExecutorFactoryManager =
+ new TaskExecutorChannelStateExecutorFactoryManager();
libraryCacheManager = TestingLibraryCacheManager.newBuilder().build();
managedMemorySize = MemoryManager.MIN_PAGE_SIZE;
this.slotAllocationSnapshotPersistenceService =
@@ -144,6 +148,12 @@ public class TaskManagerServicesBuilder {
return this;
}
+ /**
+ * Replaces the {@link TaskExecutorChannelStateExecutorFactoryManager} held by this builder.
+ *
+ * @param taskChannelStateExecutorFactoryManager the manager to use instead of the current one
+ * @return this builder
+ */
+ public TaskManagerServicesBuilder setTaskChannelStateExecutorFactoryManager(
+ TaskExecutorChannelStateExecutorFactoryManager taskChannelStateExecutorFactoryManager) {
+ this.taskChannelStateExecutorFactoryManager = taskChannelStateExecutorFactoryManager;
+ return this;
+ }
+
public TaskManagerServicesBuilder setLibraryCacheManager(
LibraryCacheManager libraryCacheManager) {
this.libraryCacheManager = libraryCacheManager;
@@ -174,6 +184,7 @@ public class TaskManagerServicesBuilder {
jobLeaderService,
taskStateManager,
taskChangelogStoragesManager,
+ taskChannelStateExecutorFactoryManager,
taskEventDispatcher,
Executors.newSingleThreadScheduledExecutor(),
libraryCacheManager,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index 1eb55ab0c04..cfefb0459b1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -218,7 +219,8 @@ public class TaskAsyncCallTest extends TestLogger {
new TestingTaskManagerRuntimeInfo(),
taskMetricGroup,
partitionProducerStateChecker,
- executor);
+ executor,
+ new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
}
/** Invokable for testing checkpoints. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java
index a2559c5e650..127600a3cbc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TestTaskBuilder.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.blob.PermanentBlobKey;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -225,7 +226,8 @@ public final class TestTaskBuilder {
new TestingTaskManagerRuntimeInfo(taskManagerConfig),
taskMetricGroup,
partitionProducerStateChecker,
- executor);
+ executor,
+ new ChannelStateWriteRequestExecutorFactory(jobId));
}
public static void setTaskState(Task task, ExecutionState state) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
index 18c026105e5..dadac718a67 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/util/JvmExitOnFatalErrorTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.core.io.InputSplit;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.runtime.blob.VoidPermanentBlobService;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -254,7 +255,9 @@ public class JvmExitOnFatalErrorTest extends TestLogger {
tmInfo,
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
new NoOpPartitionProducerStateChecker(),
- executor);
+ executor,
+ new ChannelStateWriteRequestExecutorFactory(
+ jobInformation.getJobId()));
System.err.println("starting task thread");
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java
index 7e59e0874c9..b12ce559c91 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/CheckpointConfig.java
@@ -604,6 +604,29 @@ public class CheckpointConfig implements java.io.Serializable {
ExecutionCheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT, alignedCheckpointTimeout);
}
+ /**
+ * Returns the maximum number of subtasks that share the same channel state file.
+ *
+ * @return the number of subtasks to share the same channel state file, as configured via {@link
+ * #setMaxSubtasksPerChannelStateFile(int)} or {@link
+ * ExecutionCheckpointingOptions#UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE}.
+ */
+ @PublicEvolving
+ public int getMaxSubtasksPerChannelStateFile() {
+ return configuration.get(
+ ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE);
+ }
+
+ /**
+ * Sets the maximum number of subtasks that share the same channel state file. If {@link
+ * ExecutionCheckpointingOptions#UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE} has value equal
+ * to <code>1</code>, each subtask will create a new channel state file.
+ *
+ * @param maxSubtasksPerChannelStateFile the maximum number of subtasks per channel state file;
+ * must be at least 1
+ * @throws IllegalArgumentException if {@code maxSubtasksPerChannelStateFile} is smaller than 1
+ */
+ @PublicEvolving
+ public void setMaxSubtasksPerChannelStateFile(int maxSubtasksPerChannelStateFile) {
+ // A channel state file cannot be shared by fewer than one subtask, so such values are
+ // rejected right away instead of being stored as a meaningless setting.
+ if (maxSubtasksPerChannelStateFile < 1) {
+ throw new IllegalArgumentException(
+ String.format(
+ "The maximum number of subtasks per channel state file must be at least 1, but was %d.",
+ maxSubtasksPerChannelStateFile));
+ }
+ configuration.set(
+ ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE,
+ maxSubtasksPerChannelStateFile);
+ }
+
/**
* Returns whether approximate local recovery is enabled.
*
@@ -841,6 +864,10 @@ public class CheckpointConfig implements java.io.Serializable {
configuration
.getOptional(ExecutionCheckpointingOptions.ALIGNED_CHECKPOINT_TIMEOUT)
.ifPresent(this::setAlignedCheckpointTimeout);
+ configuration
+ .getOptional(
+ ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE)
+ .ifPresent(this::setMaxSubtasksPerChannelStateFile);
configuration
.getOptional(ExecutionCheckpointingOptions.FORCE_UNALIGNED)
.ifPresent(this::setForceUnalignedCheckpoints);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java
index 02ebbd44e82..f13236ae049 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/environment/ExecutionCheckpointingOptions.java
@@ -280,4 +280,13 @@ public class ExecutionCheckpointingOptions {
.booleanType()
.defaultValue(false)
.withDescription("Flag to enable approximate local recovery.");
+
+ /**
+ * Integer option for the maximum number of subtasks that share one channel state file; the
+ * default value is 5.
+ */
+ public static final ConfigOption<Integer> UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE =
+ key("execution.checkpointing.unaligned.max-subtasks-per-channel-state-file")
+ .intType()
+ .defaultValue(5)
+ .withDescription(
+ "Defines the maximum number of subtasks that share the same channel state file. "
+ + "It can reduce the number of small files when enable unaligned checkpoint. "
+ + "Each subtask will create a new channel state file when this is configured to 1.");
}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
index 476c165c8c0..267289c181f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamConfig.java
@@ -530,6 +530,17 @@ public class StreamConfig implements Serializable {
ExecutionCheckpointingOptions.MAX_CONCURRENT_CHECKPOINTS.defaultValue());
}
+ /**
+ * Reads {@link ExecutionCheckpointingOptions#UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE} from
+ * this stream config.
+ */
+ public int getMaxSubtasksPerChannelStateFile() {
+ return config.get(
+ ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE);
+ }
+
+ /**
+ * Stores the given value under {@link
+ * ExecutionCheckpointingOptions#UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE} in this stream
+ * config.
+ *
+ * @param maxSubtasksPerChannelStateFile the value to store; it is written as given
+ */
+ public void setMaxSubtasksPerChannelStateFile(int maxSubtasksPerChannelStateFile) {
+ config.set(
+ ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE,
+ maxSubtasksPerChannelStateFile);
+ }
+
/**
* Sets the job vertex level non-chained outputs. The given output list must have the same order
* with {@link JobVertex#getProducedDataSets()}.
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index 5f96de65898..5161a2586bf 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -1009,6 +1009,7 @@ public class StreamingJobGraphGenerator {
config.setCheckpointMode(getCheckpointingMode(checkpointCfg));
config.setUnalignedCheckpointsEnabled(checkpointCfg.isUnalignedCheckpointsEnabled());
config.setAlignedCheckpointTimeout(checkpointCfg.getAlignedCheckpointTimeout());
+ config.setMaxSubtasksPerChannelStateFile(checkpointCfg.getMaxSubtasksPerChannelStateFile());
config.setMaxConcurrentCheckpoints(checkpointCfg.getMaxConcurrentCheckpoints());
for (int i = 0; i < vertex.getStatePartitioners().length; i++) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 40ad6a1fceb..ee10e7929e3 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -456,6 +456,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
this.subtaskCheckpointCoordinator =
new SubtaskCheckpointCoordinatorImpl(
+ checkpointStorage,
checkpointStorageAccess,
getName(),
actionExecutor,
@@ -471,7 +472,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
this::prepareInputSnapshot,
configuration.getMaxConcurrentCheckpoints(),
BarrierAlignmentUtil.createRegisterTimerCallback(
- mainMailboxExecutor, systemTimerService));
+ mainMailboxExecutor, systemTimerService),
+ configuration.getMaxSubtasksPerChannelStateFile());
resourceCloser.registerCloseable(subtaskCheckpointCoordinator::close);
// Register to stop all timers and threads. Should be closed first.
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
index 82b6e7f8093..c12541b7f6d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.CheckpointStateOutputStream;
import org.apache.flink.runtime.state.CheckpointStateToolset;
+import org.apache.flink.runtime.state.CheckpointStorage;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
import org.apache.flink.runtime.state.CheckpointStreamFactory;
@@ -125,7 +126,8 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
private long alignmentCheckpointId;
SubtaskCheckpointCoordinatorImpl(
- CheckpointStorageWorkerView checkpointStorage,
+ CheckpointStorage checkpointStorage,
+ CheckpointStorageWorkerView checkpointStorageView,
String taskName,
StreamTaskActionExecutor actionExecutor,
ExecutorService asyncOperationsThreadPool,
@@ -137,9 +139,10 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
ChannelStateWriter, Long, CompletableFuture<Void>, CheckpointException>
prepareInputSnapshot,
int maxRecordAbortedCheckpoints,
- DelayableTimer registerTimer) {
+ DelayableTimer registerTimer,
+ int maxSubtasksPerChannelStateFile) {
this(
- checkpointStorage,
+ checkpointStorageView,
taskName,
actionExecutor,
asyncOperationsThreadPool,
@@ -148,7 +151,8 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
prepareInputSnapshot,
maxRecordAbortedCheckpoints,
unalignedCheckpointEnabled
- ? openChannelStateWriter(taskName, checkpointStorage, env)
+ ? openChannelStateWriter(
+ taskName, checkpointStorage, env, maxSubtasksPerChannelStateFile)
: ChannelStateWriter.NO_OP,
enableCheckpointAfterTasksFinished,
registerTimer);
@@ -191,12 +195,17 @@ class SubtaskCheckpointCoordinatorImpl implements SubtaskCheckpointCoordinator {
}
private static ChannelStateWriter openChannelStateWriter(
- String taskName, CheckpointStorageWorkerView checkpointStorage, Environment env) {
- ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(
- taskName, env.getTaskInfo().getIndexOfThisSubtask(), checkpointStorage);
- writer.open();
- return writer;
+ String taskName,
+ CheckpointStorage checkpointStorage,
+ Environment env,
+ int maxSubtasksPerChannelStateFile) {
+ // The job vertex id, the subtask index and the channel state executor factory are all taken
+ // from the environment; the remaining arguments are passed through unchanged.
+ return new ChannelStateWriterImpl(
+ env.getJobVertexId(),
+ taskName,
+ env.getTaskInfo().getIndexOfThisSubtask(),
+ checkpointStorage,
+ env.getChannelStateExecutorFactory(),
+ maxSubtasksPerChannelStateFile);
}
@Override
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
index 5ddaaa0229e..26361ae31ed 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/InterruptSensitiveRestoreTest.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -293,7 +294,8 @@ public class InterruptSensitiveRestoreTest {
new TestingTaskManagerRuntimeInfo(),
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
mock(PartitionProducerStateChecker.class),
- mock(Executor.class));
+ mock(Executor.class),
+ new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
}
// ------------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java
index 2c1af6fb1ed..38b3e7dce02 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MockSubtaskCheckpointCoordinatorBuilder.java
@@ -22,8 +22,8 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.operators.testutils.MockEnvironment;
-import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
-import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorageAccess;
+import org.apache.flink.runtime.state.CheckpointStorage;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.runtime.taskmanager.AsyncExceptionHandler;
import org.apache.flink.util.concurrent.Executors;
import org.apache.flink.util.concurrent.FutureUtils;
@@ -38,7 +38,7 @@ import static org.apache.flink.streaming.runtime.tasks.StreamTaskActionExecutor.
/** A mock builder to build {@link SubtaskCheckpointCoordinator}. */
public class MockSubtaskCheckpointCoordinatorBuilder {
private String taskName = "mock-task";
- private CheckpointStorageWorkerView checkpointStorage;
+ private CheckpointStorage checkpointStorage;
private Environment environment;
private AsyncExceptionHandler asyncExceptionHandler;
private StreamTaskActionExecutor actionExecutor = IMMEDIATE;
@@ -47,6 +47,7 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
ChannelStateWriter, Long, CompletableFuture<Void>, CheckpointException>
prepareInputSnapshot = (channelStateWriter, aLong) -> FutureUtils.completedVoidFuture();
private boolean unalignedCheckpointEnabled;
+ private int maxSubtasksPerChannelStateFile = 5;
private int maxRecordAbortedCheckpoints = 10;
private boolean enableCheckpointAfterTasksFinished = true;
@@ -86,14 +87,18 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
return this;
}
+ /**
+ * Overrides the number of subtasks per channel state file used by this builder (5 unless set).
+ *
+ * @param maxSubtasksPerChannelStateFile the value to use for the built coordinator
+ * @return this builder
+ */
+ public MockSubtaskCheckpointCoordinatorBuilder setMaxSubtasksPerChannelStateFile(
+ int maxSubtasksPerChannelStateFile) {
+ this.maxSubtasksPerChannelStateFile = maxSubtasksPerChannelStateFile;
+ return this;
+ }
+
SubtaskCheckpointCoordinator build() throws IOException {
if (environment == null) {
this.environment = MockEnvironment.builder().build();
}
if (checkpointStorage == null) {
- this.checkpointStorage =
- new MemoryBackendCheckpointStorageAccess(
- environment.getJobID(), null, null, 1024);
+ this.checkpointStorage = new JobManagerCheckpointStorage();
}
if (asyncExceptionHandler == null) {
this.asyncExceptionHandler = new NonHandleAsyncException();
@@ -101,6 +106,7 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
return new SubtaskCheckpointCoordinatorImpl(
checkpointStorage,
+ checkpointStorage.createCheckpointStorage(environment.getJobID()),
taskName,
actionExecutor,
executorService,
@@ -110,7 +116,8 @@ public class MockSubtaskCheckpointCoordinatorBuilder {
enableCheckpointAfterTasksFinished,
prepareInputSnapshot,
maxRecordAbortedCheckpoints,
- (callable, duration) -> () -> {});
+ (callable, duration) -> () -> {},
+ maxSubtasksPerChannelStateFile);
}
private static class NonHandleAsyncException implements AsyncExceptionHandler {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 9d37bc16cf8..563ee0ef10d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.execution.Environment;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -120,6 +121,8 @@ public class StreamMockEnvironment implements Environment {
private final boolean collectNetworkEvents;
+ private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
+
@Nullable private Consumer<Throwable> externalExceptionHandler;
private TaskEventDispatcher taskEventDispatcher = mock(TaskEventDispatcher.class);
@@ -196,6 +199,7 @@ public class StreamMockEnvironment implements Environment {
registry.createTaskRegistry(
jobID, executionAttemptID.getExecutionVertexId().getJobVertexId());
this.collectNetworkEvents = collectNetworkEvents;
+ this.channelStateExecutorFactory = new ChannelStateWriteRequestExecutorFactory(jobID);
}
public StreamMockEnvironment(
@@ -415,4 +419,9 @@ public class StreamMockEnvironment implements Environment {
public void setCheckpointResponder(CheckpointResponder checkpointResponder) {
this.checkpointResponder = checkpointResponder;
}
+
+ @Override
+ public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
+ return channelStateExecutorFactory;
+ }
}
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java
index 578cd953509..bf330729830 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskSystemExitTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.security.FlinkSecurityManager;
import org.apache.flink.core.security.UserSystemExitException;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -205,7 +206,8 @@ public class StreamTaskSystemExitTest extends TestLogger {
taskManagerRuntimeInfo,
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
mock(PartitionProducerStateChecker.class),
- Executors.directExecutor());
+ Executors.directExecutor(),
+ new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
}
/** StreamTask emulating system exit behavior from different callback functions. */
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
index a81888a5ffc..5b08e50fe35 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTerminationTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.core.testutils.OneShotLatch;
import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -189,7 +190,8 @@ public class StreamTaskTerminationTest extends TestLogger {
taskManagerRuntimeInfo,
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
mock(PartitionProducerStateChecker.class),
- Executors.directExecutor());
+ Executors.directExecutor(),
+ new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
CompletableFuture<Void> taskRun =
CompletableFuture.runAsync(() -> task.run(), EXECUTOR_RESOURCE.getExecutor());
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
index 684568b6ce8..47354b1b6ab 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
@@ -51,6 +51,7 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.TestCheckpointStorageWorkerView;
import org.apache.flink.runtime.state.TestTaskStateManager;
+import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.streaming.api.graph.StreamConfig;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
@@ -84,7 +85,6 @@ import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
-import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.apache.flink.shaded.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
import static org.apache.flink.util.ExceptionUtils.findThrowable;
import static org.junit.Assert.assertEquals;
@@ -577,8 +577,15 @@ public class SubtaskCheckpointCoordinatorTest {
@Test
public void testChannelStateWriteResultLeakAndNotFailAfterCheckpointAborted() throws Exception {
String taskName = "test";
+ DummyEnvironment env = new DummyEnvironment();
ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
+ new ChannelStateWriterImpl(
+ env.getJobVertexId(),
+ taskName,
+ 0,
+ new JobManagerCheckpointStorage(),
+ env.getChannelStateExecutorFactory(),
+ 5);
try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
SubtaskCheckpointCoordinator coordinator =
new SubtaskCheckpointCoordinatorImpl(
@@ -586,14 +593,13 @@ public class SubtaskCheckpointCoordinatorTest {
taskName,
StreamTaskActionExecutor.IMMEDIATE,
newDirectExecutorService(),
- new DummyEnvironment(),
+ env,
(unused1, unused2) -> {},
(unused1, unused2) -> CompletableFuture.completedFuture(null),
1,
writer,
true,
(callable, duration) -> () -> {})) {
- writer.open();
final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
int checkpointId = 1;
// Abort checkpoint 1
@@ -629,23 +635,29 @@ public class SubtaskCheckpointCoordinatorTest {
CheckpointOptions unalignedOptions =
CheckpointOptions.unaligned(
CHECKPOINT, CheckpointStorageLocationReference.getDefault());
+ DummyEnvironment env = new DummyEnvironment();
+ ChannelStateWriterImpl writer =
+ new ChannelStateWriterImpl(
+ env.getJobVertexId(),
+ taskName,
+ 0,
+ new JobManagerCheckpointStorage(),
+ env.getChannelStateExecutorFactory(),
+ 5);
try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
- ChannelStateWriterImpl writer =
- new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
SubtaskCheckpointCoordinator coordinator =
new SubtaskCheckpointCoordinatorImpl(
new TestCheckpointStorageWorkerView(100),
taskName,
StreamTaskActionExecutor.IMMEDIATE,
newDirectExecutorService(),
- new DummyEnvironment(),
+ env,
(unused1, unused2) -> {},
(unused1, unused2) -> CompletableFuture.completedFuture(null),
1,
writer,
true,
(callable, duration) -> () -> {})) {
- writer.open();
final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
int checkpoint42 = 42;
int checkpoint43 = 43;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java
index 0e18799404e..53ed047170c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointITCase.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.SavepointType;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -275,7 +276,8 @@ public class SynchronousCheckpointITCase {
new TestingTaskManagerRuntimeInfo(),
taskMetricGroup,
partitionProducerStateChecker,
- executor);
+ executor,
+ new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
}
private static class TaskCleaner implements AutoCloseable {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
index f936b1f65db..bf52ca116bd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TaskCheckpointingBehaviourTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -244,7 +245,8 @@ public class TaskCheckpointingBehaviourTest extends TestLogger {
new TestingTaskManagerRuntimeInfo(),
UnregisteredMetricGroups.createUnregisteredTaskMetricGroup(),
mock(PartitionProducerStateChecker.class),
- Executors.directExecutor());
+ Executors.directExecutor(),
+ new ChannelStateWriteRequestExecutorFactory(jobInformation.getJobId()));
}
// ------------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
index 283dd913ccb..54ece459aff 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/state/ChangelogRecoveryCachingITCase.java
@@ -78,6 +78,7 @@ import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingO
import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.CHECKPOINTING_MODE;
import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.ENABLE_UNALIGNED;
import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.EXTERNALIZED_CHECKPOINT;
+import static org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions.UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE;
import static org.apache.flink.util.Preconditions.checkState;
/** Tests caching of changelog segments downloaded during recovery. */
@@ -181,6 +182,9 @@ public class ChangelogRecoveryCachingITCase extends TestLogger {
conf.set(ENABLE_UNALIGNED, true); // speedup
conf.set(ALIGNED_CHECKPOINT_TIMEOUT, Duration.ZERO); // prevent randomization
+ conf.set(
+ UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE,
+ 1); // prevent the channel state file from being opened multiple times
conf.set(BUFFER_DEBLOAT_ENABLED, false); // prevent randomization
conf.set(RESTART_STRATEGY, "none"); // not expecting any failures