You are viewing a plain-text version of this content. The canonical link to it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@flink.apache.org by pn...@apache.org on 2022/11/04 08:13:23 UTC
[flink] 02/02: [FLINK-29730][checkpoint] Do not support concurrent unaligned checkpoints in the ChannelStateWriteRequestDispatcherImpl
This is an automated email from the ASF dual-hosted git repository.
pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 00a25808dfac69ba8319b9c4dc365e13fd5b87d2
Author: 1996fanrui <19...@gmail.com>
AuthorDate: Thu Nov 3 13:52:14 2022 +0800
[FLINK-29730][checkpoint] Do not support concurrent unaligned checkpoints in the ChannelStateWriteRequestDispatcherImpl
---
.../channel/ChannelStateWriteRequest.java | 60 ++++++-----
.../ChannelStateWriteRequestDispatcherImpl.java | 115 +++++++++++++++++----
...ChannelStateWriteRequestDispatcherImplTest.java | 26 +++++
.../channel/ChannelStateWriterImplTest.java | 28 +++++
.../channel/CheckpointInProgressRequestTest.java | 3 +-
.../tasks/SubtaskCheckpointCoordinatorTest.java | 67 +++++++++++-
6 files changed, 248 insertions(+), 51 deletions(-)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
index 4b706045d0b..3fcd8975c61 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
@@ -52,15 +52,12 @@ interface ChannelStateWriteRequest {
static CheckpointInProgressRequest completeInput(long checkpointId) {
return new CheckpointInProgressRequest(
- "completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput, false);
+ "completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput);
}
static CheckpointInProgressRequest completeOutput(long checkpointId) {
return new CheckpointInProgressRequest(
- "completeOutput",
- checkpointId,
- ChannelStateCheckpointWriter::completeOutput,
- false);
+ "completeOutput", checkpointId, ChannelStateCheckpointWriter::completeOutput);
}
static ChannelStateWriteRequest write(
@@ -125,8 +122,7 @@ interface ChannelStateWriteRequest {
"Failed to recycle the output buffer of channel state.",
e);
}
- }),
- false);
+ }));
}
static ChannelStateWriteRequest buildWriteRequest(
@@ -144,8 +140,7 @@ interface ChannelStateWriteRequest {
bufferConsumer.accept(writer, buffer);
}
},
- throwable -> iterator.close(),
- false);
+ throwable -> iterator.close());
}
static void checkBufferIsBuffer(Buffer buffer) {
@@ -165,8 +160,7 @@ interface ChannelStateWriteRequest {
}
static ChannelStateWriteRequest abort(long checkpointId, Throwable cause) {
- return new CheckpointInProgressRequest(
- "abort", checkpointId, writer -> writer.fail(cause), true);
+ return new CheckpointAbortRequest(checkpointId, cause);
}
static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) {
@@ -229,29 +223,25 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
private final ThrowingConsumer<Throwable, Exception> discardAction;
private final long checkpointId;
private final String name;
- private final boolean ignoreMissingWriter;
private final AtomicReference<CheckpointInProgressRequestState> state =
new AtomicReference<>(NEW);
CheckpointInProgressRequest(
String name,
long checkpointId,
- ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
- boolean ignoreMissingWriter) {
- this(name, checkpointId, action, unused -> {}, ignoreMissingWriter);
+ ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action) {
+ this(name, checkpointId, action, unused -> {});
}
CheckpointInProgressRequest(
String name,
long checkpointId,
ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
- ThrowingConsumer<Throwable, Exception> discardAction,
- boolean ignoreMissingWriter) {
+ ThrowingConsumer<Throwable, Exception> discardAction) {
this.checkpointId = checkpointId;
this.action = checkNotNull(action);
this.discardAction = checkNotNull(discardAction);
this.name = checkNotNull(name);
- this.ignoreMissingWriter = ignoreMissingWriter;
}
@Override
@@ -277,15 +267,37 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
}
}
- void onWriterMissing() {
- if (!ignoreMissingWriter) {
- throw new IllegalArgumentException(
- "writer not found while processing request: " + toString());
- }
+ @Override
+ public String toString() {
+ return name + " " + checkpointId;
+ }
+}
+
+final class CheckpointAbortRequest implements ChannelStateWriteRequest {
+
+ private final long checkpointId;
+
+ private final Throwable throwable;
+
+ public CheckpointAbortRequest(long checkpointId, Throwable throwable) {
+ this.checkpointId = checkpointId;
+ this.throwable = throwable;
+ }
+
+ public Throwable getThrowable() {
+ return throwable;
+ }
+
+ @Override
+ public long getCheckpointId() {
+ return checkpointId;
}
+ @Override
+ public void cancel(Throwable cause) throws Exception {}
+
@Override
public String toString() {
- return name + " " + checkpointId;
+ return "Abort checkpointId-" + checkpointId;
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
index 6ab7b4bec9a..fb2e4bedc79 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
@@ -17,14 +17,14 @@
package org.apache.flink.runtime.checkpoint.channel;
+import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.util.HashMap;
-import java.util.Map;
-
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED;
+import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
@@ -36,13 +36,31 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
private static final Logger LOG =
LoggerFactory.getLogger(ChannelStateWriteRequestDispatcherImpl.class);
- private final Map<Long, ChannelStateCheckpointWriter>
- writers; // limited indirectly by results max size
private final CheckpointStorageWorkerView streamFactoryResolver;
private final ChannelStateSerializer serializer;
private final int subtaskIndex;
private final String taskName;
+ /**
+ * The checkpointId that the current {@link #writer} belongs to. It must always be updated
+ * together with {@link #writer}.
+ */
+ private long ongoingCheckpointId;
+
+ /**
+ * Every checkpoint whose checkpointId is less than or equal to maxAbortedCheckpointId should
+ * be aborted.
+ */
+ private long maxAbortedCheckpointId;
+
+ /** The cause with which checkpoint maxAbortedCheckpointId was aborted. */
+ private Throwable abortedCause;
+
+ /**
+ * The channel state writer of the ongoing checkpoint; null once it has finished, failed or been aborted.
+ */
+ private ChannelStateCheckpointWriter writer;
+
ChannelStateWriteRequestDispatcherImpl(
String taskName,
int subtaskIndex,
@@ -50,9 +68,10 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
ChannelStateSerializer serializer) {
this.taskName = taskName;
this.subtaskIndex = subtaskIndex;
- this.writers = new HashMap<>();
this.streamFactoryResolver = checkNotNull(streamFactoryResolver);
this.serializer = checkNotNull(serializer);
+ this.ongoingCheckpointId = -1;
+ this.maxAbortedCheckpointId = -1;
}
@Override
@@ -71,24 +90,68 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
}
private void dispatchInternal(ChannelStateWriteRequest request) throws Exception {
+ if (isAbortedCheckpoint(request.getCheckpointId())) {
+ if (request.getCheckpointId() == maxAbortedCheckpointId) {
+ request.cancel(abortedCause);
+ } else {
+ request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+ }
+ return;
+ }
+
if (request instanceof CheckpointStartRequest) {
checkState(
- !writers.containsKey(request.getCheckpointId()),
- "writer not found for request " + request);
- writers.put(request.getCheckpointId(), buildWriter((CheckpointStartRequest) request));
+ request.getCheckpointId() > ongoingCheckpointId,
+ String.format(
+ "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.",
+ ongoingCheckpointId, request));
+ failAndClearWriter(
+ new IllegalStateException(
+ String.format(
+ "Task[name=%s, subtaskIndex=%s] has uncompleted channelState writer of checkpointId=%s, "
+ + "but it received a new checkpoint start request of checkpointId=%s, it maybe "
+ + "a bug due to currently not supported concurrent unaligned checkpoint.",
+ taskName,
+ subtaskIndex,
+ ongoingCheckpointId,
+ request.getCheckpointId())));
+ this.writer = buildWriter((CheckpointStartRequest) request);
+ this.ongoingCheckpointId = request.getCheckpointId();
} else if (request instanceof CheckpointInProgressRequest) {
- ChannelStateCheckpointWriter writer = writers.get(request.getCheckpointId());
CheckpointInProgressRequest req = (CheckpointInProgressRequest) request;
- if (writer == null) {
- req.onWriterMissing();
- } else {
- req.execute(writer);
+ checkArgument(
+ ongoingCheckpointId == req.getCheckpointId() && writer != null,
+ "writer not found while processing request: " + req);
+ req.execute(writer);
+ } else if (request instanceof CheckpointAbortRequest) {
+ CheckpointAbortRequest req = (CheckpointAbortRequest) request;
+ if (request.getCheckpointId() > maxAbortedCheckpointId) {
+ this.maxAbortedCheckpointId = req.getCheckpointId();
+ this.abortedCause = req.getThrowable();
+ }
+
+ if (req.getCheckpointId() == ongoingCheckpointId) {
+ failAndClearWriter(req.getThrowable());
+ } else if (request.getCheckpointId() > ongoingCheckpointId) {
+ failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
}
} else {
throw new IllegalArgumentException("unknown request type: " + request);
}
}
+ private boolean isAbortedCheckpoint(long checkpointId) {
+ return checkpointId < ongoingCheckpointId || checkpointId <= maxAbortedCheckpointId;
+ }
+
+ private void failAndClearWriter(Throwable e) {
+ if (writer == null) {
+ return;
+ }
+ writer.fail(e);
+ writer = null;
+ }
+
private ChannelStateCheckpointWriter buildWriter(CheckpointStartRequest request)
throws Exception {
return new ChannelStateCheckpointWriter(
@@ -98,18 +161,26 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
streamFactoryResolver.resolveCheckpointStorageLocation(
request.getCheckpointId(), request.getLocationReference()),
serializer,
- () -> writers.remove(request.getCheckpointId()));
+ () -> {
+ checkState(
+ request.getCheckpointId() == ongoingCheckpointId,
+ "The ongoingCheckpointId[%s] was changed when clear writer of checkpoint[%s], it might be a bug.",
+ ongoingCheckpointId,
+ request.getCheckpointId());
+ this.writer = null;
+ });
}
@Override
public void fail(Throwable cause) {
- for (ChannelStateCheckpointWriter writer : writers.values()) {
- try {
- writer.fail(cause);
- } catch (Exception ex) {
- LOG.warn("unable to fail write channel state writer", cause);
- }
+ if (writer == null) {
+ return;
+ }
+ try {
+ writer.fail(cause);
+ } catch (Exception ex) {
+ LOG.warn("unable to fail write channel state writer", cause);
}
- writers.clear();
+ writer = null;
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
index 5569c6afedf..0a52f32e3e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
@@ -30,7 +30,9 @@ import org.junit.Test;
import java.util.function.Function;
+import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.apache.flink.util.CloseableIterator.ofElements;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
/** {@link ChannelStateWriteRequestDispatcherImpl} test. */
@@ -54,6 +56,30 @@ public class ChannelStateWriteRequestDispatcherImplTest {
1L, new ResultSubpartitionInfo(1, 2), buffers));
}
+ @Test
+ public void testConcurrentUnalignedCheckpoint() throws Exception {
+ ChannelStateWriteRequestDispatcher processor =
+ new ChannelStateWriteRequestDispatcherImpl(
+ "dummy task",
+ 0,
+ getStreamFactoryFactory(),
+ new ChannelStateSerializerImpl());
+ ChannelStateWriteResult result = new ChannelStateWriteResult();
+ processor.dispatch(
+ ChannelStateWriteRequest.start(
+ 1L, result, CheckpointStorageLocationReference.getDefault()));
+ assertFalse(result.isDone());
+
+ processor.dispatch(
+ ChannelStateWriteRequest.start(
+ 2L,
+ new ChannelStateWriteResult(),
+ CheckpointStorageLocationReference.getDefault()));
+ assertTrue(result.isDone());
+ assertTrue(result.getInputChannelStateHandles().isCompletedExceptionally());
+ assertTrue(result.getResultSubpartitionStateHandles().isCompletedExceptionally());
+ }
+
private void testBuffersRecycled(
Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
ChannelStateWriteRequestDispatcher dispatcher =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
index 61c2d80aa33..0aa76af9648 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
@@ -39,6 +39,7 @@ import static org.apache.flink.util.CloseableIterator.ofElements;
import static org.apache.flink.util.ExceptionUtils.findThrowable;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
/** {@link ChannelStateWriterImpl} lifecycle tests. */
public class ChannelStateWriterImplTest {
@@ -120,6 +121,33 @@ public class ChannelStateWriterImplTest {
runWithSyncWorker(this::callAbort);
}
+ @Test
+ public void testAbortOldAndStartNewCheckpoint() throws Exception {
+ runWithSyncWorker(
+ (writer, worker) -> {
+ int checkpoint42 = 42;
+ int checkpoint43 = 43;
+ writer.start(
+ checkpoint42, CheckpointOptions.forCheckpointWithDefaultLocation());
+ writer.abort(checkpoint42, new TestException(), false);
+ writer.start(
+ checkpoint43, CheckpointOptions.forCheckpointWithDefaultLocation());
+ worker.processAllRequests();
+
+ ChannelStateWriteResult result42 = writer.getAndRemoveWriteResult(checkpoint42);
+ assertTrue(result42.isDone());
+ try {
+ result42.getInputChannelStateHandles().get();
+ fail("The result should have failed.");
+ } catch (Throwable throwable) {
+ assertTrue(findThrowable(throwable, TestException.class).isPresent());
+ }
+
+ ChannelStateWriteResult result43 = writer.getAndRemoveWriteResult(checkpoint43);
+ assertFalse(result43.isDone());
+ });
+ }
+
@Test(expected = TestException.class)
public void testBuffersRecycledOnError() throws Exception {
unwrappingError(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
index 7d7c291043a..e37fbcb1404 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
@@ -67,8 +67,7 @@ public class CheckpointInProgressRequestTest {
unused -> {
cancelCounter.incrementAndGet();
await(cb);
- },
- false);
+ });
}
private void await(CyclicBarrier cb) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
index 2ddd9083a7a..bcec491d82a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
@@ -72,6 +72,7 @@ import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Map;
+import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
@@ -84,6 +85,7 @@ import java.util.function.Supplier;
import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
import static org.apache.flink.shaded.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
+import static org.apache.flink.util.ExceptionUtils.findThrowable;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
@@ -604,9 +606,6 @@ public class SubtaskCheckpointCoordinatorTest {
ChannelStateWriter.ChannelStateWriteResult writeResult =
writer.getWriteResult(checkpointId);
assertNotNull(writeResult);
- assertFalse(writeResult.isDone());
- assertFalse(writeResult.getInputChannelStateHandles().isCompletedExceptionally());
- assertFalse(writeResult.getResultSubpartitionStateHandles().isCompletedExceptionally());
coordinator.checkpointState(
new CheckpointMetaData(checkpointId, System.currentTimeMillis()),
@@ -623,6 +622,68 @@ public class SubtaskCheckpointCoordinatorTest {
}
}
+ @Test
+ public void testAbortOldAndStartNewCheckpoint() throws Exception {
+ String taskName = "test";
+ CheckpointOptions unalignedOptions =
+ CheckpointOptions.unaligned(
+ CHECKPOINT, CheckpointStorageLocationReference.getDefault());
+ try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
+ ChannelStateWriterImpl writer =
+ new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
+ SubtaskCheckpointCoordinator coordinator =
+ new SubtaskCheckpointCoordinatorImpl(
+ new TestCheckpointStorageWorkerView(100),
+ taskName,
+ StreamTaskActionExecutor.IMMEDIATE,
+ newDirectExecutorService(),
+ new DummyEnvironment(),
+ (unused1, unused2) -> {},
+ (unused1, unused2) -> CompletableFuture.completedFuture(null),
+ 1,
+ writer,
+ true,
+ (callable, duration) -> () -> {})) {
+ writer.open();
+ final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
+ int checkpoint42 = 42;
+ int checkpoint43 = 43;
+
+ coordinator.initInputsCheckpoint(checkpoint42, unalignedOptions);
+ ChannelStateWriter.ChannelStateWriteResult result42 =
+ writer.getWriteResult(checkpoint42);
+ assertNotNull(result42);
+
+ coordinator.notifyCheckpointAborted(checkpoint42, operatorChain, () -> true);
+ coordinator.initInputsCheckpoint(checkpoint43, unalignedOptions);
+ ChannelStateWriter.ChannelStateWriteResult result43 =
+ writer.getWriteResult(checkpoint43);
+
+ result42.waitForDone();
+ assertTrue(result42.isDone());
+ try {
+ result42.getInputChannelStateHandles().get();
+ fail("The result should have failed.");
+ } catch (Throwable throwable) {
+ assertTrue(findThrowable(throwable, CancellationException.class).isPresent());
+ }
+
+ // test that the new checkpoint can be completed
+ coordinator.checkpointState(
+ new CheckpointMetaData(checkpoint43, System.currentTimeMillis()),
+ unalignedOptions,
+ new CheckpointMetricsBuilder(),
+ operatorChain,
+ false,
+ () -> true);
+ result43.waitForDone();
+ assertNotNull(result43);
+ assertTrue(result43.isDone());
+ assertFalse(result43.getInputChannelStateHandles().isCompletedExceptionally());
+ assertFalse(result43.getResultSubpartitionStateHandles().isCompletedExceptionally());
+ }
+ }
+
private OperatorChain<?, ?> getOperatorChain(MockEnvironment mockEnvironment) throws Exception {
return new RegularOperatorChain<>(
new MockStreamTaskBuilder(mockEnvironment).build(), new NonRecordWriter<>());