You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by pn...@apache.org on 2022/11/04 08:13:23 UTC

[flink] 02/02: [FLINK-29730][checkpoint] Do not support concurrent unaligned checkpoints in the ChannelStateWriteRequestDispatcherImpl

This is an automated email from the ASF dual-hosted git repository.

pnowojski pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 00a25808dfac69ba8319b9c4dc365e13fd5b87d2
Author: 1996fanrui <19...@gmail.com>
AuthorDate: Thu Nov 3 13:52:14 2022 +0800

    [FLINK-29730][checkpoint] Do not support concurrent unaligned checkpoints in the ChannelStateWriteRequestDispatcherImpl
---
 .../channel/ChannelStateWriteRequest.java          |  60 ++++++-----
 .../ChannelStateWriteRequestDispatcherImpl.java    | 115 +++++++++++++++++----
 ...ChannelStateWriteRequestDispatcherImplTest.java |  26 +++++
 .../channel/ChannelStateWriterImplTest.java        |  28 +++++
 .../channel/CheckpointInProgressRequestTest.java   |   3 +-
 .../tasks/SubtaskCheckpointCoordinatorTest.java    |  67 +++++++++++-
 6 files changed, 248 insertions(+), 51 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
index 4b706045d0b..3fcd8975c61 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequest.java
@@ -52,15 +52,12 @@ interface ChannelStateWriteRequest {
 
     static CheckpointInProgressRequest completeInput(long checkpointId) {
         return new CheckpointInProgressRequest(
-                "completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput, false);
+                "completeInput", checkpointId, ChannelStateCheckpointWriter::completeInput);
     }
 
     static CheckpointInProgressRequest completeOutput(long checkpointId) {
         return new CheckpointInProgressRequest(
-                "completeOutput",
-                checkpointId,
-                ChannelStateCheckpointWriter::completeOutput,
-                false);
+                "completeOutput", checkpointId, ChannelStateCheckpointWriter::completeOutput);
     }
 
     static ChannelStateWriteRequest write(
@@ -125,8 +122,7 @@ interface ChannelStateWriteRequest {
                                                 "Failed to recycle the output buffer of channel state.",
                                                 e);
                                     }
-                                }),
-                false);
+                                }));
     }
 
     static ChannelStateWriteRequest buildWriteRequest(
@@ -144,8 +140,7 @@ interface ChannelStateWriteRequest {
                         bufferConsumer.accept(writer, buffer);
                     }
                 },
-                throwable -> iterator.close(),
-                false);
+                throwable -> iterator.close());
     }
 
     static void checkBufferIsBuffer(Buffer buffer) {
@@ -165,8 +160,7 @@ interface ChannelStateWriteRequest {
     }
 
     static ChannelStateWriteRequest abort(long checkpointId, Throwable cause) {
-        return new CheckpointInProgressRequest(
-                "abort", checkpointId, writer -> writer.fail(cause), true);
+        return new CheckpointAbortRequest(checkpointId, cause);
     }
 
     static ThrowingConsumer<Throwable, Exception> recycle(Buffer[] flinkBuffers) {
@@ -229,29 +223,25 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
     private final ThrowingConsumer<Throwable, Exception> discardAction;
     private final long checkpointId;
     private final String name;
-    private final boolean ignoreMissingWriter;
     private final AtomicReference<CheckpointInProgressRequestState> state =
             new AtomicReference<>(NEW);
 
     CheckpointInProgressRequest(
             String name,
             long checkpointId,
-            ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
-            boolean ignoreMissingWriter) {
-        this(name, checkpointId, action, unused -> {}, ignoreMissingWriter);
+            ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action) {
+        this(name, checkpointId, action, unused -> {});
     }
 
     CheckpointInProgressRequest(
             String name,
             long checkpointId,
             ThrowingConsumer<ChannelStateCheckpointWriter, Exception> action,
-            ThrowingConsumer<Throwable, Exception> discardAction,
-            boolean ignoreMissingWriter) {
+            ThrowingConsumer<Throwable, Exception> discardAction) {
         this.checkpointId = checkpointId;
         this.action = checkNotNull(action);
         this.discardAction = checkNotNull(discardAction);
         this.name = checkNotNull(name);
-        this.ignoreMissingWriter = ignoreMissingWriter;
     }
 
     @Override
@@ -277,15 +267,37 @@ final class CheckpointInProgressRequest implements ChannelStateWriteRequest {
         }
     }
 
-    void onWriterMissing() {
-        if (!ignoreMissingWriter) {
-            throw new IllegalArgumentException(
-                    "writer not found while processing request: " + toString());
-        }
+    @Override
+    public String toString() {
+        return name + " " + checkpointId;
+    }
+}
+
+final class CheckpointAbortRequest implements ChannelStateWriteRequest {
+
+    private final long checkpointId;
+
+    private final Throwable throwable;
+
+    public CheckpointAbortRequest(long checkpointId, Throwable throwable) {
+        this.checkpointId = checkpointId;
+        this.throwable = throwable;
+    }
+
+    public Throwable getThrowable() {
+        return throwable;
+    }
+
+    @Override
+    public long getCheckpointId() {
+        return checkpointId;
     }
 
+    @Override
+    public void cancel(Throwable cause) throws Exception {}
+
     @Override
     public String toString() {
-        return name + " " + checkpointId;
+        return "Abort checkpointId-" + checkpointId;
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
index 6ab7b4bec9a..fb2e4bedc79 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImpl.java
@@ -17,14 +17,14 @@
 
 package org.apache.flink.runtime.checkpoint.channel;
 
+import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.state.CheckpointStorageWorkerView;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.HashMap;
-import java.util.Map;
-
+import static org.apache.flink.runtime.checkpoint.CheckpointFailureReason.CHECKPOINT_DECLINED_SUBSUMED;
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -36,13 +36,31 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
     private static final Logger LOG =
             LoggerFactory.getLogger(ChannelStateWriteRequestDispatcherImpl.class);
 
-    private final Map<Long, ChannelStateCheckpointWriter>
-            writers; // limited indirectly by results max size
     private final CheckpointStorageWorkerView streamFactoryResolver;
     private final ChannelStateSerializer serializer;
     private final int subtaskIndex;
     private final String taskName;
 
+    /**
+     * The checkpointId corresponding to {@link #writer}. It should always be updated together with
+     * {@link #writer}.
+     */
+    private long ongoingCheckpointId;
+
+    /**
+     * Any checkpoint whose checkpointId is less than or equal to maxAbortedCheckpointId should be
+     * aborted.
+     */
+    private long maxAbortedCheckpointId;
+
+    /** The cause for which the checkpoint maxAbortedCheckpointId was aborted. */
+    private Throwable abortedCause;
+
+    /**
+     * The channelState writer of ongoing checkpointId; it can be null when the writer is finished.
+     */
+    private ChannelStateCheckpointWriter writer;
+
     ChannelStateWriteRequestDispatcherImpl(
             String taskName,
             int subtaskIndex,
@@ -50,9 +68,10 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
             ChannelStateSerializer serializer) {
         this.taskName = taskName;
         this.subtaskIndex = subtaskIndex;
-        this.writers = new HashMap<>();
         this.streamFactoryResolver = checkNotNull(streamFactoryResolver);
         this.serializer = checkNotNull(serializer);
+        this.ongoingCheckpointId = -1;
+        this.maxAbortedCheckpointId = -1;
     }
 
     @Override
@@ -71,24 +90,68 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
     }
 
     private void dispatchInternal(ChannelStateWriteRequest request) throws Exception {
+        if (isAbortedCheckpoint(request.getCheckpointId())) {
+            if (request.getCheckpointId() == maxAbortedCheckpointId) {
+                request.cancel(abortedCause);
+            } else {
+                request.cancel(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
+            }
+            return;
+        }
+
         if (request instanceof CheckpointStartRequest) {
             checkState(
-                    !writers.containsKey(request.getCheckpointId()),
-                    "writer not found for request " + request);
-            writers.put(request.getCheckpointId(), buildWriter((CheckpointStartRequest) request));
+                    request.getCheckpointId() > ongoingCheckpointId,
+                    String.format(
+                            "Checkpoint must be incremented, ongoingCheckpointId is %s, but the request is %s.",
+                            ongoingCheckpointId, request));
+            failAndClearWriter(
+                    new IllegalStateException(
+                            String.format(
+                                    "Task[name=%s, subtaskIndex=%s] has uncompleted channelState writer of checkpointId=%s, "
+                                            + "but it received a new checkpoint start request of checkpointId=%s, it maybe "
+                                            + "a bug due to currently not supported concurrent unaligned checkpoint.",
+                                    taskName,
+                                    subtaskIndex,
+                                    ongoingCheckpointId,
+                                    request.getCheckpointId())));
+            this.writer = buildWriter((CheckpointStartRequest) request);
+            this.ongoingCheckpointId = request.getCheckpointId();
         } else if (request instanceof CheckpointInProgressRequest) {
-            ChannelStateCheckpointWriter writer = writers.get(request.getCheckpointId());
             CheckpointInProgressRequest req = (CheckpointInProgressRequest) request;
-            if (writer == null) {
-                req.onWriterMissing();
-            } else {
-                req.execute(writer);
+            checkArgument(
+                    ongoingCheckpointId == req.getCheckpointId() && writer != null,
+                    "writer not found while processing request: " + req);
+            req.execute(writer);
+        } else if (request instanceof CheckpointAbortRequest) {
+            CheckpointAbortRequest req = (CheckpointAbortRequest) request;
+            if (request.getCheckpointId() > maxAbortedCheckpointId) {
+                this.maxAbortedCheckpointId = req.getCheckpointId();
+                this.abortedCause = req.getThrowable();
+            }
+
+            if (req.getCheckpointId() == ongoingCheckpointId) {
+                failAndClearWriter(req.getThrowable());
+            } else if (request.getCheckpointId() > ongoingCheckpointId) {
+                failAndClearWriter(new CheckpointException(CHECKPOINT_DECLINED_SUBSUMED));
             }
         } else {
             throw new IllegalArgumentException("unknown request type: " + request);
         }
     }
 
+    private boolean isAbortedCheckpoint(long checkpointId) {
+        return checkpointId < ongoingCheckpointId || checkpointId <= maxAbortedCheckpointId;
+    }
+
+    private void failAndClearWriter(Throwable e) {
+        if (writer == null) {
+            return;
+        }
+        writer.fail(e);
+        writer = null;
+    }
+
     private ChannelStateCheckpointWriter buildWriter(CheckpointStartRequest request)
             throws Exception {
         return new ChannelStateCheckpointWriter(
@@ -98,18 +161,26 @@ final class ChannelStateWriteRequestDispatcherImpl implements ChannelStateWriteR
                 streamFactoryResolver.resolveCheckpointStorageLocation(
                         request.getCheckpointId(), request.getLocationReference()),
                 serializer,
-                () -> writers.remove(request.getCheckpointId()));
+                () -> {
+                    checkState(
+                            request.getCheckpointId() == ongoingCheckpointId,
+                            "The ongoingCheckpointId[%s] was changed when clear writer of checkpoint[%s], it might be a bug.",
+                            ongoingCheckpointId,
+                            request.getCheckpointId());
+                    this.writer = null;
+                });
     }
 
     @Override
     public void fail(Throwable cause) {
-        for (ChannelStateCheckpointWriter writer : writers.values()) {
-            try {
-                writer.fail(cause);
-            } catch (Exception ex) {
-                LOG.warn("unable to fail write channel state writer", cause);
-            }
+        if (writer == null) {
+            return;
+        }
+        try {
+            writer.fail(cause);
+        } catch (Exception ex) {
+            LOG.warn("unable to fail write channel state writer", cause);
         }
-        writers.clear();
+        writer = null;
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
index 5569c6afedf..0a52f32e3e0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriteRequestDispatcherImplTest.java
@@ -30,7 +30,9 @@ import org.junit.Test;
 
 import java.util.function.Function;
 
+import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.util.CloseableIterator.ofElements;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 /** {@link ChannelStateWriteRequestDispatcherImpl} test. */
@@ -54,6 +56,30 @@ public class ChannelStateWriteRequestDispatcherImplTest {
                                 1L, new ResultSubpartitionInfo(1, 2), buffers));
     }
 
+    @Test
+    public void testConcurrentUnalignedCheckpoint() throws Exception {
+        ChannelStateWriteRequestDispatcher processor =
+                new ChannelStateWriteRequestDispatcherImpl(
+                        "dummy task",
+                        0,
+                        getStreamFactoryFactory(),
+                        new ChannelStateSerializerImpl());
+        ChannelStateWriteResult result = new ChannelStateWriteResult();
+        processor.dispatch(
+                ChannelStateWriteRequest.start(
+                        1L, result, CheckpointStorageLocationReference.getDefault()));
+        assertFalse(result.isDone());
+
+        processor.dispatch(
+                ChannelStateWriteRequest.start(
+                        2L,
+                        new ChannelStateWriteResult(),
+                        CheckpointStorageLocationReference.getDefault()));
+        assertTrue(result.isDone());
+        assertTrue(result.getInputChannelStateHandles().isCompletedExceptionally());
+        assertTrue(result.getResultSubpartitionStateHandles().isCompletedExceptionally());
+    }
+
     private void testBuffersRecycled(
             Function<NetworkBuffer[], ChannelStateWriteRequest> requestBuilder) throws Exception {
         ChannelStateWriteRequestDispatcher dispatcher =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
index 61c2d80aa33..0aa76af9648 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateWriterImplTest.java
@@ -39,6 +39,7 @@ import static org.apache.flink.util.CloseableIterator.ofElements;
 import static org.apache.flink.util.ExceptionUtils.findThrowable;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 
 /** {@link ChannelStateWriterImpl} lifecycle tests. */
 public class ChannelStateWriterImplTest {
@@ -120,6 +121,33 @@ public class ChannelStateWriterImplTest {
         runWithSyncWorker(this::callAbort);
     }
 
+    @Test
+    public void testAbortOldAndStartNewCheckpoint() throws Exception {
+        runWithSyncWorker(
+                (writer, worker) -> {
+                    int checkpoint42 = 42;
+                    int checkpoint43 = 43;
+                    writer.start(
+                            checkpoint42, CheckpointOptions.forCheckpointWithDefaultLocation());
+                    writer.abort(checkpoint42, new TestException(), false);
+                    writer.start(
+                            checkpoint43, CheckpointOptions.forCheckpointWithDefaultLocation());
+                    worker.processAllRequests();
+
+                    ChannelStateWriteResult result42 = writer.getAndRemoveWriteResult(checkpoint42);
+                    assertTrue(result42.isDone());
+                    try {
+                        result42.getInputChannelStateHandles().get();
+                        fail("The result should have failed.");
+                    } catch (Throwable throwable) {
+                        assertTrue(findThrowable(throwable, TestException.class).isPresent());
+                    }
+
+                    ChannelStateWriteResult result43 = writer.getAndRemoveWriteResult(checkpoint43);
+                    assertFalse(result43.isDone());
+                });
+    }
+
     @Test(expected = TestException.class)
     public void testBuffersRecycledOnError() throws Exception {
         unwrappingError(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
index 7d7c291043a..e37fbcb1404 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/channel/CheckpointInProgressRequestTest.java
@@ -67,8 +67,7 @@ public class CheckpointInProgressRequestTest {
                 unused -> {
                     cancelCounter.incrementAndGet();
                     await(cb);
-                },
-                false);
+                });
     }
 
     private void await(CyclicBarrier cb) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
index 2ddd9083a7a..bcec491d82a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorTest.java
@@ -72,6 +72,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Map;
+import java.util.concurrent.CancellationException;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutionException;
@@ -84,6 +85,7 @@ import java.util.function.Supplier;
 import static org.apache.flink.runtime.checkpoint.CheckpointType.CHECKPOINT;
 import static org.apache.flink.runtime.state.ChannelPersistenceITCase.getStreamFactoryFactory;
 import static org.apache.flink.shaded.guava30.com.google.common.util.concurrent.MoreExecutors.newDirectExecutorService;
+import static org.apache.flink.util.ExceptionUtils.findThrowable;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
@@ -604,9 +606,6 @@ public class SubtaskCheckpointCoordinatorTest {
             ChannelStateWriter.ChannelStateWriteResult writeResult =
                     writer.getWriteResult(checkpointId);
             assertNotNull(writeResult);
-            assertFalse(writeResult.isDone());
-            assertFalse(writeResult.getInputChannelStateHandles().isCompletedExceptionally());
-            assertFalse(writeResult.getResultSubpartitionStateHandles().isCompletedExceptionally());
 
             coordinator.checkpointState(
                     new CheckpointMetaData(checkpointId, System.currentTimeMillis()),
@@ -623,6 +622,68 @@ public class SubtaskCheckpointCoordinatorTest {
         }
     }
 
+    @Test
+    public void testAbortOldAndStartNewCheckpoint() throws Exception {
+        String taskName = "test";
+        CheckpointOptions unalignedOptions =
+                CheckpointOptions.unaligned(
+                        CHECKPOINT, CheckpointStorageLocationReference.getDefault());
+        try (MockEnvironment mockEnvironment = MockEnvironment.builder().build();
+                ChannelStateWriterImpl writer =
+                        new ChannelStateWriterImpl(taskName, 0, getStreamFactoryFactory());
+                SubtaskCheckpointCoordinator coordinator =
+                        new SubtaskCheckpointCoordinatorImpl(
+                                new TestCheckpointStorageWorkerView(100),
+                                taskName,
+                                StreamTaskActionExecutor.IMMEDIATE,
+                                newDirectExecutorService(),
+                                new DummyEnvironment(),
+                                (unused1, unused2) -> {},
+                                (unused1, unused2) -> CompletableFuture.completedFuture(null),
+                                1,
+                                writer,
+                                true,
+                                (callable, duration) -> () -> {})) {
+            writer.open();
+            final OperatorChain<?, ?> operatorChain = getOperatorChain(mockEnvironment);
+            int checkpoint42 = 42;
+            int checkpoint43 = 43;
+
+            coordinator.initInputsCheckpoint(checkpoint42, unalignedOptions);
+            ChannelStateWriter.ChannelStateWriteResult result42 =
+                    writer.getWriteResult(checkpoint42);
+            assertNotNull(result42);
+
+            coordinator.notifyCheckpointAborted(checkpoint42, operatorChain, () -> true);
+            coordinator.initInputsCheckpoint(checkpoint43, unalignedOptions);
+            ChannelStateWriter.ChannelStateWriteResult result43 =
+                    writer.getWriteResult(checkpoint43);
+
+            result42.waitForDone();
+            assertTrue(result42.isDone());
+            try {
+                result42.getInputChannelStateHandles().get();
+                fail("The result should have failed.");
+            } catch (Throwable throwable) {
+                assertTrue(findThrowable(throwable, CancellationException.class).isPresent());
+            }
+
+            // test the new checkpoint can be completed
+            coordinator.checkpointState(
+                    new CheckpointMetaData(checkpoint43, System.currentTimeMillis()),
+                    unalignedOptions,
+                    new CheckpointMetricsBuilder(),
+                    operatorChain,
+                    false,
+                    () -> true);
+            result43.waitForDone();
+            assertNotNull(result43);
+            assertTrue(result43.isDone());
+            assertFalse(result43.getInputChannelStateHandles().isCompletedExceptionally());
+            assertFalse(result43.getResultSubpartitionStateHandles().isCompletedExceptionally());
+        }
+    }
+
     private OperatorChain<?, ?> getOperatorChain(MockEnvironment mockEnvironment) throws Exception {
         return new RegularOperatorChain<>(
                 new MockStreamTaskBuilder(mockEnvironment).build(), new NonRecordWriter<>());