You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:41 UTC

[flink-ml] 03/08: [FLINK-24655][iteration] Make head operator aligned with coordinator for each checkpoint

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 5896237262fc1327ab465dbbaeceaa8402c23c09
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Oct 4 21:20:49 2021 +0800

    [FLINK-24655][iteration] Make head operator aligned with coordinator for each checkpoint
---
 .../flink/iteration/operator/HeadOperator.java     |  63 +++++++++--
 .../operator/HeadOperatorCheckpointAligner.java    | 104 +++++++++++++++++
 .../coordinator/HeadOperatorCoordinator.java       |  36 +++---
 .../coordinator/SharedProgressAligner.java         | 107 ++++++++++++++----
 .../coordinator/SharedProgressAlignerListener.java |  30 +++++
 .../operator/event/CoordinatorCheckpointEvent.java |  59 ++++++++++
 .../flink/iteration/operator/HeadOperatorTest.java | 117 +++++++++++++++++++-
 .../coordinator/SharedProgressAlignerTest.java     | 123 ++++++++++++++-------
 8 files changed, 554 insertions(+), 85 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index d7a9a54..6e5a7b4 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -23,9 +23,11 @@ import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.broadcast.BroadcastOutput;
 import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
@@ -36,6 +38,7 @@ import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
 import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
@@ -49,6 +52,7 @@ import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.OutputTag;
 
@@ -59,8 +63,23 @@ import java.util.concurrent.Executor;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
- * The head operator unions the initialized variable stream and the feedback stream, and synchronize
- * the epoch watermark (round).
+ * The head operator unions the initialized variable stream and the feedback stream, synchronizes
+ * the epoch watermark (round), and takes care of the checkpoints.
+ *
+ * <p>Specifically, for checkpoints the head operator:
+ *
+ * <ul>
+ *   <li>Ensures exactly-once processing of elements.
+ *   <li>Ensures exactly-once calls to {@link IterationListener#onEpochWatermarkIncremented(int,
+ *       IterationListener.Context, Collector)}.
+ * </ul>
+ *
+ * <p>To achieve the first goal, the head operator also needs to include in the snapshot the
+ * records received from the feedback edge between alignment and the arrival of the barrier. To
+ * achieve the second goal, the head operator also waits for the notification from the
+ * OperatorCoordinator in addition to the task inputs. This ensures that the {@link
+ * GloballyAlignedEvent} does not interleave with the epoch watermarks and that all tasks inside
+ * the iteration are notified with the same epochs, which facilitates rescaling in the future.
  */
 public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         implements OneInputStreamOperator<IterationRecord<?>, IterationRecord<?>>,
@@ -91,6 +110,8 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     private HeadOperatorRecordProcessor recordProcessor;
 
+    private HeadOperatorCheckpointAligner checkpointAligner;
+
     public HeadOperator(
             IterationID iterationId,
             int feedbackIndex,
@@ -129,6 +150,8 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         status = HeadOperatorStatus.RUNNING;
         recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
 
+        checkpointAligner = new HeadOperatorCheckpointAligner();
+
         // Here we register a mail
         registerFeedbackConsumer(
                 (Runnable runnable) -> {
@@ -139,6 +162,21 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     }
 
     @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        super.prepareSnapshotPreBarrier(checkpointId);
+
+        checkpointAligner.waitTillCoordinatorNotified(checkpointId, mailboxExecutor::yield);
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        checkpointAligner
+                .onStateSnapshot(context.getCheckpointId())
+                .forEach(this::processGloballyAlignedEvent);
+    }
+
+    @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
         recordProcessor.processElement(element);
     }
@@ -155,12 +193,21 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     @Override
     public void handleOperatorEvent(OperatorEvent operatorEvent) {
         if (operatorEvent instanceof GloballyAlignedEvent) {
-            boolean shouldTerminate =
-                    recordProcessor.onGloballyAligned((GloballyAlignedEvent) operatorEvent);
-            if (shouldTerminate) {
-                status = HeadOperatorStatus.TERMINATING;
-                recordProcessor = new TerminatingHeadOperatorRecordProcessor();
-            }
+            checkpointAligner
+                    .checkHoldingGloballyAlignedEvent((GloballyAlignedEvent) operatorEvent)
+                    .ifPresent(this::processGloballyAlignedEvent);
+        } else if (operatorEvent instanceof CoordinatorCheckpointEvent) {
+            checkpointAligner.coordinatorNotify((CoordinatorCheckpointEvent) operatorEvent);
+        } else {
+            throw new FlinkRuntimeException("Unsupported operator event: " + operatorEvent);
+        }
+    }
+
+    private void processGloballyAlignedEvent(GloballyAlignedEvent globallyAlignedEvent) {
+        boolean shouldTerminate = recordProcessor.onGloballyAligned(globallyAlignedEvent);
+        if (shouldTerminate) {
+            status = HeadOperatorStatus.TERMINATING;
+            recordProcessor = new TerminatingHeadOperatorRecordProcessor();
         }
     }
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
new file mode 100644
index 0000000..b4f4215
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator;
+
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.util.function.RunnableWithException;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.TreeMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Aligns the checkpoint barrier from the task inputs and the checkpoint event from the coordinator.
+ * In addition, it holds the other operator events received after the checkpoint event until the
+ * state is snapshotted.
+ */
+class HeadOperatorCheckpointAligner {
+
+    private final TreeMap<Long, CheckpointAlignment> checkpointAlignmments;
+
+    private long latestCheckpointFromCoordinator;
+
+    HeadOperatorCheckpointAligner() {
+        this.checkpointAlignmments = new TreeMap<>();
+    }
+
+    void waitTillCoordinatorNotified(long checkpointId, RunnableWithException defaultAction)
+            throws Exception {
+        CheckpointAlignment checkpointAlignment =
+                checkpointAlignmments.computeIfAbsent(
+                        checkpointId, ignored -> new CheckpointAlignment(true, false));
+        while (!checkpointAlignment.notifiedFromCoordinator) {
+            defaultAction.run();
+        }
+        checkpointAlignment.notifiedFromChannels = true;
+    }
+
+    void coordinatorNotify(CoordinatorCheckpointEvent checkpointEvent) {
+        checkState(checkpointEvent.getCheckpointId() > latestCheckpointFromCoordinator);
+        latestCheckpointFromCoordinator = checkpointEvent.getCheckpointId();
+        CheckpointAlignment checkpointAlignment =
+                checkpointAlignmments.computeIfAbsent(
+                        checkpointEvent.getCheckpointId(),
+                        ignored -> new CheckpointAlignment(false, true));
+        checkpointAlignment.notifiedFromCoordinator = true;
+    }
+
+    Optional<GloballyAlignedEvent> checkHoldingGloballyAlignedEvent(
+            GloballyAlignedEvent globallyAlignedEvent) {
+        CheckpointAlignment checkpointAlignment =
+                checkpointAlignmments.get(latestCheckpointFromCoordinator);
+        if (checkpointAlignment != null && !checkpointAlignment.notifiedFromChannels) {
+            checkpointAlignment.pendingGlobalEvents.add(globallyAlignedEvent);
+            return Optional.empty();
+        }
+
+        return Optional.of(globallyAlignedEvent);
+    }
+
+    List<GloballyAlignedEvent> onStateSnapshot(long checkpointId) {
+        CheckpointAlignment checkpointAlignment = checkpointAlignmments.remove(checkpointId);
+        checkState(
+                checkpointAlignment.notifiedFromCoordinator
+                        && checkpointAlignment.notifiedFromChannels,
+                "Checkpoint " + checkpointId + " is not fully aligned");
+        return checkpointAlignment.pendingGlobalEvents;
+    }
+
+    private static class CheckpointAlignment {
+
+        final List<GloballyAlignedEvent> pendingGlobalEvents;
+
+        boolean notifiedFromChannels;
+
+        boolean notifiedFromCoordinator;
+
+        public CheckpointAlignment(boolean notifiedFromChannels, boolean notifiedFromCoordinator) {
+            this.pendingGlobalEvents = new ArrayList<>();
+
+            this.notifiedFromChannels = notifiedFromChannels;
+            this.notifiedFromCoordinator = notifiedFromCoordinator;
+        }
+    }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
index bde35a2..2c01c33 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.coordinator;
 
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.operator.HeadOperator;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -37,7 +38,7 @@ import java.util.concurrent.Executors;
  * SharedProgressAligner} when received aligned event from the operator, and emit the globally
  * aligned event back after one round is globally aligned.
  */
-public class HeadOperatorCoordinator implements OperatorCoordinator {
+public class HeadOperatorCoordinator implements OperatorCoordinator, SharedProgressAlignerListener {
 
     private final Context context;
 
@@ -50,18 +51,24 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
         this.sharedProgressAligner = Objects.requireNonNull(sharedProgressAligner);
         this.subtaskGateways = new SubtaskGateway[context.currentParallelism()];
 
-        sharedProgressAligner.registerAlignedConsumer(context.getOperatorId(), this::onAligned);
+        sharedProgressAligner.registerAlignedListener(context.getOperatorId(), this);
     }
 
     @Override
     public void start() {}
 
     @Override
-    public void subtaskReady(int i, SubtaskGateway subtaskGateway) {
-        this.subtaskGateways[i] = subtaskGateway;
+    public void subtaskReady(int subtaskIndex, SubtaskGateway subtaskGateway) {
+        this.subtaskGateways[subtaskIndex] = subtaskGateway;
     }
 
     @Override
+    public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {}
+
+    @Override
+    public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {}
+
+    @Override
     public void handleEventFromOperator(int subtaskIndex, OperatorEvent operatorEvent) {
         if (operatorEvent instanceof SubtaskAlignedEvent) {
             sharedProgressAligner.reportSubtaskProgress(
@@ -71,6 +78,11 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
         }
     }
 
+    @Override
+    public void checkpointCoordinator(long l, CompletableFuture<byte[]> completableFuture) {
+        sharedProgressAligner.requestCheckpoint(l, context.currentParallelism(), completableFuture);
+    }
+
     public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
         for (int i = 0; i < context.currentParallelism(); ++i) {
             subtaskGateways[i].sendEvent(globallyAlignedEvent);
@@ -78,25 +90,21 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
     }
 
     @Override
-    public void close() {
-        sharedProgressAligner.unregisterConsumer(context.getOperatorId());
+    public void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent) {
+        for (int i = 0; i < context.currentParallelism(); ++i) {
+            subtaskGateways[i].sendEvent(coordinatorCheckpointEvent);
+        }
     }
 
     @Override
-    public void checkpointCoordinator(long l, CompletableFuture<byte[]> completableFuture) {
-        completableFuture.complete(new byte[0]);
+    public void close() {
+        sharedProgressAligner.unregisterListener(context.getOperatorId());
     }
 
     @Override
     public void notifyCheckpointComplete(long l) {}
 
     @Override
-    public void resetToCheckpoint(long l, @Nullable byte[] bytes) {}
-
-    @Override
-    public void subtaskFailed(int i, @Nullable Throwable throwable) {}
-
-    @Override
     public void subtaskReset(int i, long l) {}
 
     /** The factory of {@link HeadOperatorCoordinator}. */
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
index 6d03a5e..9c2fd43 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.coordinator;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -31,12 +32,14 @@ import org.apache.flink.util.function.ThrowingRunnable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
-import java.util.function.Consumer;
 import java.util.function.Supplier;
 
 import static org.apache.flink.util.Preconditions.checkState;
@@ -44,7 +47,7 @@ import static org.apache.flink.util.Preconditions.checkState;
 /**
  * The progress aligner shared between multiple {@link HeadOperatorCoordinator}. It maintains the
  * information for each round, once one round is aligned, it would notify all the register
- * consumers.
+ * listeners.
  */
 public class SharedProgressAligner {
 
@@ -63,7 +66,9 @@ public class SharedProgressAligner {
 
     private final Map<Integer, EpochStatus> statusByEpoch;
 
-    private final Map<OperatorID, Consumer<GloballyAlignedEvent>> alignedConsumers;
+    private final Map<OperatorID, SharedProgressAlignerListener> listeners;
+
+    private final Map<Long, CheckpointStatus> checkpointStatuses;
 
     public static SharedProgressAligner getOrCreate(
             IterationID iterationId,
@@ -93,29 +98,28 @@ public class SharedProgressAligner {
         this.executor = Objects.requireNonNull(executor);
 
         this.statusByEpoch = new HashMap<>();
-        this.alignedConsumers = new HashMap<>();
+        this.listeners = new HashMap<>();
+        this.checkpointStatuses = new HashMap<>();
     }
 
-    public void registerAlignedConsumer(
-            OperatorID operatorID, Consumer<GloballyAlignedEvent> alignedConsumer) {
+    public void registerAlignedListener(
+            OperatorID operatorID, SharedProgressAlignerListener alignedConsumer) {
         runInEventLoop(
-                () -> this.alignedConsumers.put(operatorID, alignedConsumer),
-                "Register consumer %s",
+                () -> this.listeners.put(operatorID, alignedConsumer),
+                "Register listeners %s",
                 operatorID.toHexString());
     }
 
-    public void unregisterConsumer(OperatorID operatorID) {
-        synchronized (this) {
-            runInEventLoop(
-                    () -> {
-                        this.alignedConsumers.remove(operatorID);
-                        if (alignedConsumers.isEmpty()) {
-                            instances.remove(iterationId);
-                        }
-                    },
-                    "Unregister consumer %s",
-                    operatorID.toHexString());
-        }
+    public void unregisterListener(OperatorID operatorID) {
+        runInEventLoop(
+                () -> {
+                    this.listeners.remove(operatorID);
+                    if (listeners.isEmpty()) {
+                        instances.remove(iterationId);
+                    }
+                },
+                "Unregister listeners %s",
+                operatorID.toHexString());
     }
 
     public void reportSubtaskProgress(
@@ -137,8 +141,8 @@ public class SharedProgressAligner {
                         GloballyAlignedEvent globallyAlignedEvent =
                                 new GloballyAlignedEvent(
                                         subtaskAlignedEvent.getEpoch(), roundStatus.isTerminated());
-                        for (Consumer<GloballyAlignedEvent> consumer : alignedConsumers.values()) {
-                            consumer.accept(globallyAlignedEvent);
+                        for (SharedProgressAlignerListener listeners : listeners.values()) {
+                            listeners.onAligned(globallyAlignedEvent);
                         }
                     }
                 },
@@ -147,6 +151,37 @@ public class SharedProgressAligner {
                 subtaskIndex);
     }
 
+    public void requestCheckpoint(
+            long checkpointId,
+            int operatorParallelism,
+            CompletableFuture<byte[]> snapshotStateFuture) {
+        runInEventLoop(
+                () -> {
+                    CheckpointStatus checkpointStatus =
+                            checkpointStatuses.computeIfAbsent(
+                                    checkpointId,
+                                    ignored -> new CheckpointStatus(totalHeadParallelism));
+                    boolean aligned =
+                            checkpointStatus.notify(operatorParallelism, snapshotStateFuture);
+                    if (aligned) {
+                        CoordinatorCheckpointEvent checkpointEvent =
+                                new CoordinatorCheckpointEvent(checkpointId);
+                        for (SharedProgressAlignerListener listener : listeners.values()) {
+                            listener.onCheckpointAligned(checkpointEvent);
+                        }
+
+                        for (CompletableFuture<byte[]> stateFuture :
+                                checkpointStatus.getStateFutures()) {
+                            stateFuture.complete(new byte[0]);
+                        }
+
+                        checkpointStatuses.remove(checkpointId);
+                    }
+                },
+                "Coordinator report checkpoint %d",
+                checkpointId);
+    }
+
     private void runInEventLoop(
             ThrowingRunnable<Throwable> action,
             String actionName,
@@ -170,8 +205,8 @@ public class SharedProgressAligner {
     }
 
     @VisibleForTesting
-    int getNumberConsumers() {
-        return alignedConsumers.size();
+    int getNumberListeners() {
+        return listeners.size();
     }
 
     private static class EpochStatus {
@@ -224,4 +259,28 @@ public class SharedProgressAligner {
             return totalRecord == 0 || (hasCriteriaStream && totalCriteriaRecord == 0);
         }
     }
+
+    private static class CheckpointStatus {
+
+        private final long totalHeadParallelism;
+
+        private final List<CompletableFuture<byte[]>> stateFutures = new ArrayList<>();
+
+        private int notifiedCoordinatorParallelism;
+
+        private CheckpointStatus(long totalHeadParallelism) {
+            this.totalHeadParallelism = totalHeadParallelism;
+        }
+
+        public boolean notify(int parallelism, CompletableFuture<byte[]> stateFuture) {
+            stateFutures.add(stateFuture);
+            notifiedCoordinatorParallelism += parallelism;
+
+            return notifiedCoordinatorParallelism == totalHeadParallelism;
+        }
+
+        public List<CompletableFuture<byte[]>> getStateFutures() {
+            return stateFutures;
+        }
+    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java
new file mode 100644
index 0000000..6b687dd
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.coordinator;
+
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+
+/** The listener of the {@link SharedProgressAligner}. */
+public interface SharedProgressAlignerListener {
+
+    void onAligned(GloballyAlignedEvent globallyAlignedEvent);
+
+    void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent);
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java
new file mode 100644
index 0000000..9c97581
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.event;
+
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+
+import java.util.Objects;
+
+/** The event notifying that the coordinator has received the request of a checkpoint. */
+public class CoordinatorCheckpointEvent implements OperatorEvent {
+
+    private final long checkpointId;
+
+    public CoordinatorCheckpointEvent(long checkpointId) {
+        this.checkpointId = checkpointId;
+    }
+
+    public long getCheckpointId() {
+        return checkpointId;
+    }
+
+    @Override
+    public String toString() {
+        return "CoordinatorCheckpointEvent{" + "checkpointId=" + checkpointId + '}';
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (!(o instanceof CoordinatorCheckpointEvent)) {
+            return false;
+        }
+        CoordinatorCheckpointEvent that = (CoordinatorCheckpointEvent) o;
+        return checkpointId == that.checkpointId;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(checkpointId);
+    }
+}
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index 3ded30a..e6d1e26 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -21,15 +21,20 @@ package org.apache.flink.iteration.operator;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -152,7 +157,7 @@ public class HeadOperatorTest extends TestLogger {
 
                                             while (RecordingHeadOperatorFactory.latestHeadOperator
                                                             .getStatus()
-                                                    == HeadOperator.HeadOperatorStatus.RUNNING) ;
+                                                    == HeadOperator.HeadOperatorStatus.RUNNING) {}
                                             putFeedbackRecords(
                                                     iterationId,
                                                     IterationRecord.newEpochWatermark(
@@ -197,6 +202,116 @@ public class HeadOperatorTest extends TestLogger {
                 });
     }
 
+    @Test(timeout = 60000)
+    public void testHoldCheckpointTillCoordinatorNotified() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    CompletableFuture<Void> coordinatorResult =
+                            CompletableFuture.supplyAsync(
+                                    () -> {
+                                        try {
+                                            // Slightly postpone the notification
+                                            Thread.sleep(2000);
+
+                                            harness.getStreamTask()
+                                                    .dispatchOperatorEvent(
+                                                            operatorId,
+                                                            new SerializedValue<>(
+                                                                    new GloballyAlignedEvent(
+                                                                            5, false)));
+                                            harness.getStreamTask()
+                                                    .dispatchOperatorEvent(
+                                                            operatorId,
+                                                            new SerializedValue<>(
+                                                                    new CoordinatorCheckpointEvent(
+                                                                            5)));
+                                            return null;
+                                        } catch (Throwable e) {
+                                            RecordingHeadOperatorFactory.latestHeadOperator
+                                                    .getMailboxExecutor()
+                                                    .execute(
+                                                            () -> {
+                                                                throw e;
+                                                            },
+                                                            "poison mail");
+                                            throw new CompletionException(e);
+                                        }
+                                    });
+
+                    CheckpointBarrier barrier =
+                            new CheckpointBarrier(
+                                    5,
+                                    5000,
+                                    CheckpointOptions.alignedNoTimeout(
+                                            CheckpointType.CHECKPOINT,
+                                            CheckpointStorageLocationReference.getDefault()));
+                    harness.processEvent(barrier);
+
+                    // There should be no exception
+                    coordinatorResult.get();
+
+                    // If the task did not hold the checkpoint, it would likely snapshot its state
+                    // before receiving the globally aligned event.
+                    assertEquals(
+                            Arrays.asList(
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    5,
+                                                    OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0),
+                                    barrier),
+                            new ArrayList<>(harness.getOutput()));
+                    return null;
+                });
+    }
+
+    @Test(timeout = 60000)
+    public void testPostponeGloballyAlignedEventsAfterSnapshot() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.getStreamTask()
+                            .dispatchOperatorEvent(
+                                    operatorId,
+                                    new SerializedValue<>(new CoordinatorCheckpointEvent(5)));
+                    harness.getStreamTask()
+                            .dispatchOperatorEvent(
+                                    operatorId,
+                                    new SerializedValue<>(new GloballyAlignedEvent(5, false)));
+                    CheckpointBarrier barrier =
+                            new CheckpointBarrier(
+                                    5,
+                                    5000,
+                                    CheckpointOptions.alignedNoTimeout(
+                                            CheckpointType.CHECKPOINT,
+                                            CheckpointStorageLocationReference.getDefault()));
+                    harness.processEvent(barrier);
+                    harness.processAll();
+
+                    assertEquals(
+                            Arrays.asList(
+                                    barrier,
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    5,
+                                                    OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0)),
+                            new ArrayList<>(harness.getOutput()));
+                    return null;
+                });
+    }
+
     private <T> T createHarnessAndRun(
             IterationID iterationId,
             OperatorID operatorId,
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
index 89ecba6..350f7bf 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.coordinator;
 
 import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -32,7 +33,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Consumer;
+import java.util.concurrent.CompletableFuture;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -64,22 +65,22 @@ public class SharedProgressAlignerTest extends TestLogger {
     }
 
     @Test
-    public void testRegisterAndUnregisterConsumers() {
+    public void testRegisterAndUnregisterListeners() {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
-        List<Consumer<GloballyAlignedEvent>> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<SharedProgressAlignerListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, Arrays.asList(2, 3), consumers);
+                initializeAligner(iterationId, operatorIds, Arrays.asList(2, 3), listeners);
 
-        assertEquals(2, aligner.getNumberConsumers());
+        assertEquals(2, aligner.getNumberListeners());
 
-        aligner.unregisterConsumer(operatorIds.get(0));
-        assertEquals(1, aligner.getNumberConsumers());
+        aligner.unregisterListener(operatorIds.get(0));
+        assertEquals(1, aligner.getNumberListeners());
         assertTrue(SharedProgressAligner.getInstances().containsKey(iterationId));
 
-        aligner.unregisterConsumer(operatorIds.get(1));
-        assertEquals(0, aligner.getNumberConsumers());
+        aligner.unregisterListener(operatorIds.get(1));
+        assertEquals(0, aligner.getNumberListeners());
         assertFalse(SharedProgressAligner.getInstances().containsKey(iterationId));
     }
 
@@ -88,10 +89,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -100,8 +101,8 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(2, false)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(2, false)), listeners);
     }
 
     @Test
@@ -109,10 +110,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -121,8 +122,8 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(2, true)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(2, true)), listeners);
     }
 
     @Test
@@ -130,10 +131,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -142,8 +143,8 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(0, false)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(0, false)), listeners);
     }
 
     @Test
@@ -151,10 +152,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -164,15 +165,46 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(2, true)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(2, true)), listeners);
+    }
+
+    @Test
+    public void testSendEventsBeforeCompleteCheckpoint() {
+        IterationID iterationId = new IterationID();
+        List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
+        List<Integer> parallelisms = Arrays.asList(2, 3);
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
+        SharedProgressAligner aligner =
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
+
+        List<CompletableFuture<byte[]>> firstCheckpointStateFutures =
+                Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>());
+        for (int i = 0; i < operatorIds.size(); ++i) {
+            // Operator 0 is the criteria stream
+            aligner.requestCheckpoint(1, parallelisms.get(i), firstCheckpointStateFutures.get(i));
+        }
+
+        List<CompletableFuture<byte[]>> secondCheckpointStateFutures =
+                Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>());
+        for (int i = 0; i < operatorIds.size(); ++i) {
+            // Operator 0 is the criteria stream
+            aligner.requestCheckpoint(2, parallelisms.get(i), secondCheckpointStateFutures.get(i));
+        }
+
+        firstCheckpointStateFutures.forEach(future -> assertTrue(future.isDone()));
+        secondCheckpointStateFutures.forEach(future -> assertTrue(future.isDone()));
+        checkCoordinatorCheckpointEvents(
+                Arrays.asList(new CoordinatorCheckpointEvent(1), new CoordinatorCheckpointEvent(2)),
+                listeners);
     }
 
     private SharedProgressAligner initializeAligner(
             IterationID iterationId,
             List<OperatorID> operatorIds,
             List<Integer> parallelisms,
-            List<? extends Consumer<GloballyAlignedEvent>> consumers) {
+            List<? extends SharedProgressAlignerListener> listeners) {
 
         SharedProgressAligner aligner =
                 SharedProgressAligner.getOrCreate(
@@ -181,28 +213,43 @@ public class SharedProgressAlignerTest extends TestLogger {
                         new MockOperatorCoordinatorContext(operatorIds.get(0), parallelisms.get(0)),
                         DirectScheduledExecutorService::new);
 
-        for (int i = 0; i < consumers.size(); ++i) {
-            aligner.registerAlignedConsumer(operatorIds.get(i), consumers.get(i));
+        for (int i = 0; i < listeners.size(); ++i) {
+            aligner.registerAlignedListener(operatorIds.get(i), listeners.get(i));
         }
 
         return aligner;
     }
 
-    private void checkRecordingConsumers(
+    private void checkGloballyAlignedEvents(
             List<GloballyAlignedEvent> expectedGloballyAlignedEvents,
-            List<RecordingConsumer> consumers) {
-        for (RecordingConsumer consumer : consumers) {
+            List<RecordingListener> listeners) {
+        for (RecordingListener consumer : listeners) {
             assertEquals(expectedGloballyAlignedEvents, consumer.globallyAlignedEvents);
         }
     }
 
-    private static class RecordingConsumer implements Consumer<GloballyAlignedEvent> {
+    private void checkCoordinatorCheckpointEvents(
+            List<CoordinatorCheckpointEvent> expectedGloballyAlignedEvents,
+            List<RecordingListener> listeners) {
+        for (RecordingListener consumer : listeners) {
+            assertEquals(expectedGloballyAlignedEvents, consumer.checkpointEvents);
+        }
+    }
+
+    private static class RecordingListener implements SharedProgressAlignerListener {
 
         final List<GloballyAlignedEvent> globallyAlignedEvents = new ArrayList<>();
 
+        final List<CoordinatorCheckpointEvent> checkpointEvents = new ArrayList<>();
+
         @Override
-        public void accept(GloballyAlignedEvent globallyAlignedEvent) {
+        public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
             globallyAlignedEvents.add(globallyAlignedEvent);
         }
+
+        @Override
+        public void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent) {
+            checkpointEvents.add(coordinatorCheckpointEvent);
+        }
     }
 }