Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:38 UTC

[flink-ml] branch master updated (b1253c0 -> b9ee412)

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git.


    from b1253c0  [FLINK-24653][iteration] Support per-round operators inside the iteration
     new 63a82c2  [FLINK-24655][iteration] HeadOperator waits for MAX_WATERMARK iterates back before terminating.
     new 9a15fb4  [hotfix][iteration] Simplify the head operator test
     new 5896237  [FLINK-24655][iteration] Make head operator aligned with coordinator for each checkpoint
     new 7c99864  [FLINK-24655][iteration] Support snapshot the feedback records on checkpoint
     new 4b47e5d  [FLINK-24655][iteration] Support the checkpoints for the iteration
     new 31ffe6c  [FLINK-24655][iteration] Skip the repeat round for all-round operator
     new 9172ec7  [FLINK-24655][iteration] Do not rely on the precedent tasks to insert epoch watermark
     new b9ee412  [FLINK-24655][iteration] Add ITCase for the checkpoint and failover

The 8 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../org/apache/flink/iteration/Iterations.java     |  29 +-
 .../flink/iteration/checkpoint/Checkpoints.java    | 189 +++++
 .../iteration/checkpoint/CheckpointsBroker.java    |  57 ++
 .../datacache/nonkeyed/DataCacheSnapshot.java      | 224 ++++++
 .../datacache/nonkeyed/DataCacheWriter.java        |  20 +-
 .../iteration/datacache/nonkeyed/Segment.java      |   5 +
 .../operator/AbstractWrapperOperator.java          |  10 +-
 .../flink/iteration/operator/HeadOperator.java     | 436 ++++++++--
 .../operator/HeadOperatorCheckpointAligner.java    | 130 +++
 .../iteration/operator/HeadOperatorFactory.java    |   8 +
 .../flink/iteration/operator/InputOperator.java    |  20 +-
 .../iteration/operator/OperatorStateUtils.java     |  29 +-
 .../flink/iteration/operator/OperatorUtils.java    |   5 +-
 .../flink/iteration/operator/ReplayOperator.java   |  28 +-
 .../flink/iteration/operator/TailOperator.java     |  27 +
 .../allround/AbstractAllRoundWrapperOperator.java  | 164 +++-
 .../MultipleInputAllRoundWrapperOperator.java      |   2 +
 .../allround/TwoInputAllRoundWrapperOperator.java  |   2 +
 .../coordinator/HeadOperatorCoordinator.java       |  42 +-
 .../coordinator/SharedProgressAligner.java         | 135 +++-
 .../SharedProgressAlignerListener.java}            |  21 +-
 ...dEvent.java => CoordinatorCheckpointEvent.java} |  42 +-
 .../headprocessor/HeadOperatorRecordProcessor.java |  63 ++
 .../operator/headprocessor/HeadOperatorState.java  |  52 ++
 .../RegularHeadOperatorRecordProcessor.java        | 190 +++++
 .../TerminatingHeadOperatorRecordProcessor.java    |  63 ++
 .../OperatorEpochWatermarkTracker.java             |  24 +-
 .../flink/iteration/IterationConstructionTest.java |  40 +-
 .../datacache/nonkeyed/DataCacheSnapshotTest.java  | 213 +++++
 .../flink/iteration/operator/HeadOperatorTest.java | 881 ++++++++++++++++++---
 .../iteration/operator/InputOperatorTest.java      |  34 +-
 .../iteration/operator/ReplayOperatorTest.java     |  21 +-
 .../OneInputAllRoundWrapperOperatorTest.java       |  70 ++
 .../coordinator/SharedProgressAlignerTest.java     | 123 ++-
 .../OperatorEpochWatermarkTrackerTest.java         |  17 +
 .../iteration/BoundedAllRoundCheckpointTest.java   | 196 +++++
 .../iteration/UnboundedStreamIterationITCase.java  |   5 +-
 .../{EpochRecord.java => FailingMap.java}          |  38 +-
 .../operators/ReduceAllRoundProcessFunction.java   |  55 +-
 .../test/iteration/operators/SequenceSource.java   |  40 +-
 .../TwoInputReduceAllRoundProcessFunction.java     |  16 +-
 41 files changed, 3312 insertions(+), 454 deletions(-)
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
 copy flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/CollectSink.java => flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java (52%)
 copy flink-ml-iteration/src/main/java/org/apache/flink/iteration/{proxy/state/StateNamePrefix.java => operator/coordinator/SharedProgressAlignerListener.java} (63%)
 copy flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/{GloballyAlignedEvent.java => CoordinatorCheckpointEvent.java} (54%)
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
 create mode 100644 flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
 create mode 100644 flink-ml-iteration/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshotTest.java
 create mode 100644 flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
 copy flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/{EpochRecord.java => FailingMap.java} (56%)

[flink-ml] 03/08: [FLINK-24655][iteration] Make head operator aligned with coordinator for each checkpoint

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 5896237262fc1327ab465dbbaeceaa8402c23c09
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Oct 4 21:20:49 2021 +0800

    [FLINK-24655][iteration] Make head operator aligned with coordinator for each checkpoint
---
 .../flink/iteration/operator/HeadOperator.java     |  63 +++++++++--
 .../operator/HeadOperatorCheckpointAligner.java    | 104 +++++++++++++++++
 .../coordinator/HeadOperatorCoordinator.java       |  36 +++---
 .../coordinator/SharedProgressAligner.java         | 107 ++++++++++++++----
 .../coordinator/SharedProgressAlignerListener.java |  30 +++++
 .../operator/event/CoordinatorCheckpointEvent.java |  59 ++++++++++
 .../flink/iteration/operator/HeadOperatorTest.java | 117 +++++++++++++++++++-
 .../coordinator/SharedProgressAlignerTest.java     | 123 ++++++++++++++-------
 8 files changed, 554 insertions(+), 85 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index d7a9a54..6e5a7b4 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -23,9 +23,11 @@ import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.broadcast.BroadcastOutput;
 import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
@@ -36,6 +38,7 @@ import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
 import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
@@ -49,6 +52,7 @@ import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.OutputTag;
 
@@ -59,8 +63,23 @@ import java.util.concurrent.Executor;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
- * The head operator unions the initialized variable stream and the feedback stream, and synchronize
- * the epoch watermark (round).
+ * The head operator unions the initialized variable stream and the feedback stream, synchronizes
+ * the epoch watermark (round), and takes care of the checkpoints.
+ *
+ * <p>Specifically for checkpoints, the head operator needs to
+ *
+ * <ul>
+ *   <li>Ensure exactly-once processing of elements.
+ *   <li>Ensure exactly-once invocation of {@link IterationListener#onEpochWatermarkIncremented(int,
+ *       IterationListener.Context, Collector)}.
+ * </ul>
+ *
+ * <p>For the first goal, the head operator must also include in the snapshot the records received
+ * from the feedback edge between the alignment and the arrival of the barrier. For the second
+ * goal, the head operator also waits for the notification from the OperatorCoordinator in addition
+ * to the task inputs. This ensures that the {@link GloballyAlignedEvent} does not interleave with
+ * the epoch watermarks and that all tasks inside the iteration are notified of the same epochs,
+ * which facilitates rescaling in the future.
  */
 public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         implements OneInputStreamOperator<IterationRecord<?>, IterationRecord<?>>,
@@ -91,6 +110,8 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     private HeadOperatorRecordProcessor recordProcessor;
 
+    private HeadOperatorCheckpointAligner checkpointAligner;
+
     public HeadOperator(
             IterationID iterationId,
             int feedbackIndex,
@@ -129,6 +150,8 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         status = HeadOperatorStatus.RUNNING;
         recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
 
+        checkpointAligner = new HeadOperatorCheckpointAligner();
+
         // Here we register a mail
         registerFeedbackConsumer(
                 (Runnable runnable) -> {
@@ -139,6 +162,21 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     }
 
     @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        super.prepareSnapshotPreBarrier(checkpointId);
+
+        checkpointAligner.waitTillCoordinatorNotified(checkpointId, mailboxExecutor::yield);
+    }
+
+    @Override
+    public void snapshotState(StateSnapshotContext context) throws Exception {
+        super.snapshotState(context);
+        checkpointAligner
+                .onStateSnapshot(context.getCheckpointId())
+                .forEach(this::processGloballyAlignedEvent);
+    }
+
+    @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
         recordProcessor.processElement(element);
     }
@@ -155,12 +193,21 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     @Override
     public void handleOperatorEvent(OperatorEvent operatorEvent) {
         if (operatorEvent instanceof GloballyAlignedEvent) {
-            boolean shouldTerminate =
-                    recordProcessor.onGloballyAligned((GloballyAlignedEvent) operatorEvent);
-            if (shouldTerminate) {
-                status = HeadOperatorStatus.TERMINATING;
-                recordProcessor = new TerminatingHeadOperatorRecordProcessor();
-            }
+            checkpointAligner
+                    .checkHoldingGloballyAlignedEvent((GloballyAlignedEvent) operatorEvent)
+                    .ifPresent(this::processGloballyAlignedEvent);
+        } else if (operatorEvent instanceof CoordinatorCheckpointEvent) {
+            checkpointAligner.coordinatorNotify((CoordinatorCheckpointEvent) operatorEvent);
+        } else {
+            throw new FlinkRuntimeException("Unsupported operator event: " + operatorEvent);
+        }
+    }
+
+    private void processGloballyAlignedEvent(GloballyAlignedEvent globallyAlignedEvent) {
+        boolean shouldTerminate = recordProcessor.onGloballyAligned(globallyAlignedEvent);
+        if (shouldTerminate) {
+            status = HeadOperatorStatus.TERMINATING;
+            recordProcessor = new TerminatingHeadOperatorRecordProcessor();
         }
     }
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
new file mode 100644
index 0000000..b4f4215
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator;
+
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.util.function.RunnableWithException;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.TreeMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Aligns the checkpoint barrier from the task inputs with the checkpoint event from the
+ * coordinator. Besides, it holds back the other operator events that arrive after the checkpoint
+ * event until the state snapshot is taken.
+ */
+class HeadOperatorCheckpointAligner {
+
+    private final TreeMap<Long, CheckpointAlignment> checkpointAlignments;
+
+    private long latestCheckpointFromCoordinator;
+
+    HeadOperatorCheckpointAligner() {
+        this.checkpointAlignments = new TreeMap<>();
+    }
+
+    void waitTillCoordinatorNotified(long checkpointId, RunnableWithException defaultAction)
+            throws Exception {
+        CheckpointAlignment checkpointAlignment =
+                checkpointAlignments.computeIfAbsent(
+                        checkpointId, ignored -> new CheckpointAlignment(true, false));
+        while (!checkpointAlignment.notifiedFromCoordinator) {
+            defaultAction.run();
+        }
+        checkpointAlignment.notifiedFromChannels = true;
+    }
+
+    void coordinatorNotify(CoordinatorCheckpointEvent checkpointEvent) {
+        checkState(checkpointEvent.getCheckpointId() > latestCheckpointFromCoordinator);
+        latestCheckpointFromCoordinator = checkpointEvent.getCheckpointId();
+        CheckpointAlignment checkpointAlignment =
+                checkpointAlignments.computeIfAbsent(
+                        checkpointEvent.getCheckpointId(),
+                        ignored -> new CheckpointAlignment(false, true));
+        checkpointAlignment.notifiedFromCoordinator = true;
+    }
+
+    Optional<GloballyAlignedEvent> checkHoldingGloballyAlignedEvent(
+            GloballyAlignedEvent globallyAlignedEvent) {
+        CheckpointAlignment checkpointAlignment =
+                checkpointAlignments.get(latestCheckpointFromCoordinator);
+        if (checkpointAlignment != null && !checkpointAlignment.notifiedFromChannels) {
+            checkpointAlignment.pendingGlobalEvents.add(globallyAlignedEvent);
+            return Optional.empty();
+        }
+
+        return Optional.of(globallyAlignedEvent);
+    }
+
+    List<GloballyAlignedEvent> onStateSnapshot(long checkpointId) {
+        CheckpointAlignment checkpointAlignment = checkpointAlignments.remove(checkpointId);
+        checkState(
+                checkpointAlignment.notifiedFromCoordinator
+                        && checkpointAlignment.notifiedFromChannels,
+                "Checkpoint " + checkpointId + " is not fully aligned");
+        return checkpointAlignment.pendingGlobalEvents;
+    }
+
+    private static class CheckpointAlignment {
+
+        final List<GloballyAlignedEvent> pendingGlobalEvents;
+
+        boolean notifiedFromChannels;
+
+        boolean notifiedFromCoordinator;
+
+        public CheckpointAlignment(boolean notifiedFromChannels, boolean notifiedFromCoordinator) {
+            this.pendingGlobalEvents = new ArrayList<>();
+
+            this.notifiedFromChannels = notifiedFromChannels;
+            this.notifiedFromCoordinator = notifiedFromCoordinator;
+        }
+    }
+}
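
To make the aligner's contract concrete, the snippet below walks through the
case where the coordinator notification and a globally aligned event both
arrive before the barrier. It is an illustrative sketch, not part of the
commit; since the aligner and its methods are package-private, such code would
have to live in the org.apache.flink.iteration.operator package, with
flink-ml-iteration on the classpath:

    import java.util.List;
    import java.util.Optional;

    import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
    import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;

    public class AlignerWalkthrough {
        public static void main(String[] args) throws Exception {
            HeadOperatorCheckpointAligner aligner = new HeadOperatorCheckpointAligner();

            // 1. The coordinator announces checkpoint 5 before the barrier.
            aligner.coordinatorNotify(new CoordinatorCheckpointEvent(5));

            // 2. A globally aligned event arriving now is held back, because
            //    the barrier for checkpoint 5 has not been processed yet.
            Optional<GloballyAlignedEvent> immediate =
                    aligner.checkHoldingGloballyAlignedEvent(
                            new GloballyAlignedEvent(5, false));
            System.out.println(immediate.isPresent()); // false: held back

            // 3. The barrier arrives; the coordinator has already notified,
            //    so this returns without running the default (yield) action.
            aligner.waitTillCoordinatorNotified(5, () -> {});

            // 4. Once the snapshot for checkpoint 5 is taken, the held event
            //    is released for processing.
            List<GloballyAlignedEvent> released = aligner.onStateSnapshot(5);
            System.out.println(released.size()); // 1
        }
    }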
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
index bde35a2..2c01c33 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.coordinator;
 
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.operator.HeadOperator;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -37,7 +38,7 @@ import java.util.concurrent.Executors;
  * SharedProgressAligner} when received aligned event from the operator, and emit the globally
  * aligned event back after one round is globally aligned.
  */
-public class HeadOperatorCoordinator implements OperatorCoordinator {
+public class HeadOperatorCoordinator implements OperatorCoordinator, SharedProgressAlignerListener {
 
     private final Context context;
 
@@ -50,18 +51,24 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
         this.sharedProgressAligner = Objects.requireNonNull(sharedProgressAligner);
         this.subtaskGateways = new SubtaskGateway[context.currentParallelism()];
 
-        sharedProgressAligner.registerAlignedConsumer(context.getOperatorId(), this::onAligned);
+        sharedProgressAligner.registerAlignedListener(context.getOperatorId(), this);
     }
 
     @Override
     public void start() {}
 
     @Override
-    public void subtaskReady(int i, SubtaskGateway subtaskGateway) {
-        this.subtaskGateways[i] = subtaskGateway;
+    public void subtaskReady(int subtaskIndex, SubtaskGateway subtaskGateway) {
+        this.subtaskGateways[subtaskIndex] = subtaskGateway;
     }
 
     @Override
+    public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {}
+
+    @Override
+    public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {}
+
+    @Override
     public void handleEventFromOperator(int subtaskIndex, OperatorEvent operatorEvent) {
         if (operatorEvent instanceof SubtaskAlignedEvent) {
             sharedProgressAligner.reportSubtaskProgress(
@@ -71,6 +78,11 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
         }
     }
 
+    @Override
+    public void checkpointCoordinator(long l, CompletableFuture<byte[]> completableFuture) {
+        sharedProgressAligner.requestCheckpoint(l, context.currentParallelism(), completableFuture);
+    }
+
     public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
         for (int i = 0; i < context.currentParallelism(); ++i) {
             subtaskGateways[i].sendEvent(globallyAlignedEvent);
@@ -78,25 +90,21 @@ public class HeadOperatorCoordinator implements OperatorCoordinator {
     }
 
     @Override
-    public void close() {
-        sharedProgressAligner.unregisterConsumer(context.getOperatorId());
+    public void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent) {
+        for (int i = 0; i < context.currentParallelism(); ++i) {
+            subtaskGateways[i].sendEvent(coordinatorCheckpointEvent);
+        }
     }
 
     @Override
-    public void checkpointCoordinator(long l, CompletableFuture<byte[]> completableFuture) {
-        completableFuture.complete(new byte[0]);
+    public void close() {
+        sharedProgressAligner.unregisterListener(context.getOperatorId());
     }
 
     @Override
     public void notifyCheckpointComplete(long l) {}
 
     @Override
-    public void resetToCheckpoint(long l, @Nullable byte[] bytes) {}
-
-    @Override
-    public void subtaskFailed(int i, @Nullable Throwable throwable) {}
-
-    @Override
     public void subtaskReset(int i, long l) {}
 
     /** The factory of {@link HeadOperatorCoordinator}. */
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
index 6d03a5e..9c2fd43 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.coordinator;
 
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -31,12 +32,14 @@ import org.apache.flink.util.function.ThrowingRunnable;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
+import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
-import java.util.function.Consumer;
 import java.util.function.Supplier;
 
 import static org.apache.flink.util.Preconditions.checkState;
@@ -44,7 +47,7 @@ import static org.apache.flink.util.Preconditions.checkState;
 /**
  * The progress aligner shared between multiple {@link HeadOperatorCoordinator}. It maintains the
  * information for each round, once one round is aligned, it would notify all the register
- * consumers.
+ * listeners.
  */
 public class SharedProgressAligner {
 
@@ -63,7 +66,9 @@ public class SharedProgressAligner {
 
     private final Map<Integer, EpochStatus> statusByEpoch;
 
-    private final Map<OperatorID, Consumer<GloballyAlignedEvent>> alignedConsumers;
+    private final Map<OperatorID, SharedProgressAlignerListener> listeners;
+
+    private final Map<Long, CheckpointStatus> checkpointStatuses;
 
     public static SharedProgressAligner getOrCreate(
             IterationID iterationId,
@@ -93,29 +98,28 @@ public class SharedProgressAligner {
         this.executor = Objects.requireNonNull(executor);
 
         this.statusByEpoch = new HashMap<>();
-        this.alignedConsumers = new HashMap<>();
+        this.listeners = new HashMap<>();
+        this.checkpointStatuses = new HashMap<>();
     }
 
-    public void registerAlignedConsumer(
-            OperatorID operatorID, Consumer<GloballyAlignedEvent> alignedConsumer) {
+    public void registerAlignedListener(
+            OperatorID operatorID, SharedProgressAlignerListener alignedListener) {
         runInEventLoop(
-                () -> this.alignedConsumers.put(operatorID, alignedConsumer),
-                "Register consumer %s",
+                () -> this.listeners.put(operatorID, alignedListener),
+                "Register listener %s",
                 operatorID.toHexString());
     }
 
-    public void unregisterConsumer(OperatorID operatorID) {
-        synchronized (this) {
-            runInEventLoop(
-                    () -> {
-                        this.alignedConsumers.remove(operatorID);
-                        if (alignedConsumers.isEmpty()) {
-                            instances.remove(iterationId);
-                        }
-                    },
-                    "Unregister consumer %s",
-                    operatorID.toHexString());
-        }
+    public void unregisterListener(OperatorID operatorID) {
+        runInEventLoop(
+                () -> {
+                    this.listeners.remove(operatorID);
+                    if (listeners.isEmpty()) {
+                        instances.remove(iterationId);
+                    }
+                },
+                "Unregister listeners %s",
+                operatorID.toHexString());
     }
 
     public void reportSubtaskProgress(
@@ -137,8 +141,8 @@ public class SharedProgressAligner {
                         GloballyAlignedEvent globallyAlignedEvent =
                                 new GloballyAlignedEvent(
                                         subtaskAlignedEvent.getEpoch(), roundStatus.isTerminated());
-                        for (Consumer<GloballyAlignedEvent> consumer : alignedConsumers.values()) {
-                            consumer.accept(globallyAlignedEvent);
+                        for (SharedProgressAlignerListener listener : listeners.values()) {
+                            listener.onAligned(globallyAlignedEvent);
                         }
                     }
                 },
@@ -147,6 +151,37 @@ public class SharedProgressAligner {
                 subtaskIndex);
     }
 
+    public void requestCheckpoint(
+            long checkpointId,
+            int operatorParallelism,
+            CompletableFuture<byte[]> snapshotStateFuture) {
+        runInEventLoop(
+                () -> {
+                    CheckpointStatus checkpointStatus =
+                            checkpointStatuses.computeIfAbsent(
+                                    checkpointId,
+                                    ignored -> new CheckpointStatus(totalHeadParallelism));
+                    boolean aligned =
+                            checkpointStatus.notify(operatorParallelism, snapshotStateFuture);
+                    if (aligned) {
+                        CoordinatorCheckpointEvent checkpointEvent =
+                                new CoordinatorCheckpointEvent(checkpointId);
+                        for (SharedProgressAlignerListener listener : listeners.values()) {
+                            listener.onCheckpointAligned(checkpointEvent);
+                        }
+
+                        for (CompletableFuture<byte[]> stateFuture :
+                                checkpointStatus.getStateFutures()) {
+                            stateFuture.complete(new byte[0]);
+                        }
+
+                        checkpointStatuses.remove(checkpointId);
+                    }
+                },
+                "Coordinator report checkpoint %d",
+                checkpointId);
+    }
+
     private void runInEventLoop(
             ThrowingRunnable<Throwable> action,
             String actionName,
@@ -170,8 +205,8 @@ public class SharedProgressAligner {
     }
 
     @VisibleForTesting
-    int getNumberConsumers() {
-        return alignedConsumers.size();
+    int getNumberListeners() {
+        return listeners.size();
     }
 
     private static class EpochStatus {
@@ -224,4 +259,28 @@ public class SharedProgressAligner {
             return totalRecord == 0 || (hasCriteriaStream && totalCriteriaRecord == 0);
         }
     }
+
+    private static class CheckpointStatus {
+
+        private final long totalHeadParallelism;
+
+        private final List<CompletableFuture<byte[]>> stateFutures = new ArrayList<>();
+
+        private int notifiedCoordinatorParallelism;
+
+        private CheckpointStatus(long totalHeadParallelism) {
+            this.totalHeadParallelism = totalHeadParallelism;
+        }
+
+        public boolean notify(int parallelism, CompletableFuture<byte[]> stateFuture) {
+            stateFutures.add(stateFuture);
+            notifiedCoordinatorParallelism += parallelism;
+
+            return notifiedCoordinatorParallelism == totalHeadParallelism;
+        }
+
+        public List<CompletableFuture<byte[]>> getStateFutures() {
+            return stateFutures;
+        }
+    }
 }
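
The CheckpointStatus bookkeeping above can be summarized in isolation: each
head coordinator reports its parallelism, and the pending state futures only
complete once the reported sum reaches the total head parallelism. Below is a
self-contained model of that logic (plain Java with illustrative names, no
Flink dependency; the real class additionally runs inside the aligner's event
loop):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.CompletableFuture;

    public class CheckpointAggregationModel {
        private final long totalHeadParallelism;
        private final List<CompletableFuture<byte[]>> stateFutures = new ArrayList<>();
        private int notifiedParallelism;

        public CheckpointAggregationModel(long totalHeadParallelism) {
            this.totalHeadParallelism = totalHeadParallelism;
        }

        public void report(int parallelism, CompletableFuture<byte[]> stateFuture) {
            stateFutures.add(stateFuture);
            notifiedParallelism += parallelism;
            if (notifiedParallelism == totalHeadParallelism) {
                // All head coordinators have reported; the coordinator itself
                // holds no state, so complete every future with zero bytes.
                stateFutures.forEach(future -> future.complete(new byte[0]));
            }
        }

        public static void main(String[] args) {
            CheckpointAggregationModel status = new CheckpointAggregationModel(5);
            CompletableFuture<byte[]> first = new CompletableFuture<>();
            CompletableFuture<byte[]> second = new CompletableFuture<>();
            status.report(2, first);  // 2 of 5 head subtasks reported
            System.out.println(first.isDone()); // false
            status.report(3, second); // 5 of 5: both futures complete
            System.out.println(first.isDone() && second.isDone()); // true
        }
    }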
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java
new file mode 100644
index 0000000..6b687dd
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerListener.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.coordinator;
+
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+
+/** The listener of the {@link SharedProgressAligner}. */
+public interface SharedProgressAlignerListener {
+
+    void onAligned(GloballyAlignedEvent globallyAlignedEvent);
+
+    void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent);
+}
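
Implementing the interface only requires the two callbacks. As a minimal
illustrative example (not part of the commit), a listener that simply logs
what the aligner reports could look like:

    import org.apache.flink.iteration.operator.coordinator.SharedProgressAlignerListener;
    import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
    import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;

    public class LoggingAlignerListener implements SharedProgressAlignerListener {

        @Override
        public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
            System.out.println("Round aligned: " + globallyAlignedEvent);
        }

        @Override
        public void onCheckpointAligned(CoordinatorCheckpointEvent checkpointEvent) {
            System.out.println("Checkpoint aligned: " + checkpointEvent);
        }
    }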
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java
new file mode 100644
index 0000000..9c97581
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/event/CoordinatorCheckpointEvent.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.event;
+
+import org.apache.flink.runtime.operators.coordination.OperatorEvent;
+
+import java.util.Objects;
+
+/** The event sent from the coordinator to the head operators when a checkpoint is requested. */
+public class CoordinatorCheckpointEvent implements OperatorEvent {
+
+    private final long checkpointId;
+
+    public CoordinatorCheckpointEvent(long checkpointId) {
+        this.checkpointId = checkpointId;
+    }
+
+    public long getCheckpointId() {
+        return checkpointId;
+    }
+
+    @Override
+    public String toString() {
+        return "CoordinatorCheckpointEvent{" + "checkpointId=" + checkpointId + '}';
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (!(o instanceof CoordinatorCheckpointEvent)) {
+            return false;
+        }
+        CoordinatorCheckpointEvent that = (CoordinatorCheckpointEvent) o;
+        return checkpointId == that.checkpointId;
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(checkpointId);
+    }
+}
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index 3ded30a..e6d1e26 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -21,15 +21,20 @@ package org.apache.flink.iteration.operator;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -152,7 +157,7 @@ public class HeadOperatorTest extends TestLogger {
 
                                             while (RecordingHeadOperatorFactory.latestHeadOperator
                                                             .getStatus()
-                                                    == HeadOperator.HeadOperatorStatus.RUNNING) ;
+                                                    == HeadOperator.HeadOperatorStatus.RUNNING) {}
                                             putFeedbackRecords(
                                                     iterationId,
                                                     IterationRecord.newEpochWatermark(
@@ -197,6 +202,116 @@ public class HeadOperatorTest extends TestLogger {
                 });
     }
 
+    @Test(timeout = 60000)
+    public void testHoldCheckpointTillCoordinatorNotified() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    CompletableFuture<Void> coordinatorResult =
+                            CompletableFuture.supplyAsync(
+                                    () -> {
+                                        try {
                                            // Slightly postpone the notification
+                                            Thread.sleep(2000);
+
+                                            harness.getStreamTask()
+                                                    .dispatchOperatorEvent(
+                                                            operatorId,
+                                                            new SerializedValue<>(
+                                                                    new GloballyAlignedEvent(
+                                                                            5, false)));
+                                            harness.getStreamTask()
+                                                    .dispatchOperatorEvent(
+                                                            operatorId,
+                                                            new SerializedValue<>(
+                                                                    new CoordinatorCheckpointEvent(
+                                                                            5)));
+                                            return null;
+                                        } catch (Throwable e) {
+                                            RecordingHeadOperatorFactory.latestHeadOperator
+                                                    .getMailboxExecutor()
+                                                    .execute(
+                                                            () -> {
+                                                                throw e;
+                                                            },
+                                                            "poison mail");
+                                            throw new CompletionException(e);
+                                        }
+                                    });
+
+                    CheckpointBarrier barrier =
+                            new CheckpointBarrier(
+                                    5,
+                                    5000,
+                                    CheckpointOptions.alignedNoTimeout(
+                                            CheckpointType.CHECKPOINT,
+                                            CheckpointStorageLocationReference.getDefault()));
+                    harness.processEvent(barrier);
+
+                    // There should be no exception
+                    coordinatorResult.get();
+
+                    // If the task did not hold the barrier, it would likely snapshot its state
+                    // before receiving the globally aligned event.
+                    assertEquals(
+                            Arrays.asList(
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    5,
+                                                    OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0),
+                                    barrier),
+                            new ArrayList<>(harness.getOutput()));
+                    return null;
+                });
+    }
+
+    @Test(timeout = 60000)
+    public void testPostponeGloballyAlignedEventsAfterSnapshot() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.getStreamTask()
+                            .dispatchOperatorEvent(
+                                    operatorId,
+                                    new SerializedValue<>(new CoordinatorCheckpointEvent(5)));
+                    harness.getStreamTask()
+                            .dispatchOperatorEvent(
+                                    operatorId,
+                                    new SerializedValue<>(new GloballyAlignedEvent(5, false)));
+                    CheckpointBarrier barrier =
+                            new CheckpointBarrier(
+                                    5,
+                                    5000,
+                                    CheckpointOptions.alignedNoTimeout(
+                                            CheckpointType.CHECKPOINT,
+                                            CheckpointStorageLocationReference.getDefault()));
+                    harness.processEvent(barrier);
+                    harness.processAll();
+
+                    assertEquals(
+                            Arrays.asList(
+                                    barrier,
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    5,
+                                                    OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0)),
+                            new ArrayList<>(harness.getOutput()));
+                    return null;
+                });
+    }
+
     private <T> T createHarnessAndRun(
             IterationID iterationId,
             OperatorID operatorId,
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
index 89ecba6..350f7bf 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAlignerTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.coordinator;
 
 import org.apache.flink.iteration.IterationID;
+import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -32,7 +33,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.Consumer;
+import java.util.concurrent.CompletableFuture;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
@@ -64,22 +65,22 @@ public class SharedProgressAlignerTest extends TestLogger {
     }
 
     @Test
-    public void testRegisterAndUnregisterConsumers() {
+    public void testRegisterAndUnregisterListeners() {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
-        List<Consumer<GloballyAlignedEvent>> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<SharedProgressAlignerListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, Arrays.asList(2, 3), consumers);
+                initializeAligner(iterationId, operatorIds, Arrays.asList(2, 3), listeners);
 
-        assertEquals(2, aligner.getNumberConsumers());
+        assertEquals(2, aligner.getNumberListeners());
 
-        aligner.unregisterConsumer(operatorIds.get(0));
-        assertEquals(1, aligner.getNumberConsumers());
+        aligner.unregisterListener(operatorIds.get(0));
+        assertEquals(1, aligner.getNumberListeners());
         assertTrue(SharedProgressAligner.getInstances().containsKey(iterationId));
 
-        aligner.unregisterConsumer(operatorIds.get(1));
-        assertEquals(0, aligner.getNumberConsumers());
+        aligner.unregisterListener(operatorIds.get(1));
+        assertEquals(0, aligner.getNumberListeners());
         assertFalse(SharedProgressAligner.getInstances().containsKey(iterationId));
     }
 
@@ -88,10 +89,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -100,8 +101,8 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(2, false)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(2, false)), listeners);
     }
 
     @Test
@@ -109,10 +110,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -121,8 +122,8 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(2, true)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(2, true)), listeners);
     }
 
     @Test
@@ -130,10 +131,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -142,8 +143,8 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(0, false)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(0, false)), listeners);
     }
 
     @Test
@@ -151,10 +152,10 @@ public class SharedProgressAlignerTest extends TestLogger {
         IterationID iterationId = new IterationID();
         List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
         List<Integer> parallelisms = Arrays.asList(2, 3);
-        List<RecordingConsumer> consumers =
-                Arrays.asList(new RecordingConsumer(), new RecordingConsumer());
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
         SharedProgressAligner aligner =
-                initializeAligner(iterationId, operatorIds, parallelisms, consumers);
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
 
         for (int i = 0; i < operatorIds.size(); ++i) {
             for (int j = 0; j < parallelisms.get(i); ++j) {
@@ -164,15 +165,46 @@ public class SharedProgressAlignerTest extends TestLogger {
             }
         }
 
-        checkRecordingConsumers(
-                Collections.singletonList(new GloballyAlignedEvent(2, true)), consumers);
+        this.checkGloballyAlignedEvents(
+                Collections.singletonList(new GloballyAlignedEvent(2, true)), listeners);
+    }
+
+    @Test
+    public void testSendEventsBeforeCompleteCheckpoint() {
+        IterationID iterationId = new IterationID();
+        List<OperatorID> operatorIds = Arrays.asList(new OperatorID(), new OperatorID());
+        List<Integer> parallelisms = Arrays.asList(2, 3);
+        List<RecordingListener> listeners =
+                Arrays.asList(new RecordingListener(), new RecordingListener());
+        SharedProgressAligner aligner =
+                initializeAligner(iterationId, operatorIds, parallelisms, listeners);
+
+        List<CompletableFuture<byte[]>> firstCheckpointStateFutures =
+                Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>());
+        for (int i = 0; i < operatorIds.size(); ++i) {
+            // Operator 0 is the criteria stream
+            aligner.requestCheckpoint(1, parallelisms.get(i), firstCheckpointStateFutures.get(i));
+        }
+
+        List<CompletableFuture<byte[]>> secondCheckpointStateFutures =
+                Arrays.asList(new CompletableFuture<>(), new CompletableFuture<>());
+        for (int i = 0; i < operatorIds.size(); ++i) {
+            // Operator 0 is the criteria stream
+            aligner.requestCheckpoint(2, parallelisms.get(i), secondCheckpointStateFutures.get(i));
+        }
+
+        firstCheckpointStateFutures.forEach(future -> assertTrue(future.isDone()));
+        secondCheckpointStateFutures.forEach(future -> assertTrue(future.isDone()));
+        checkCoordinatorCheckpointEvents(
+                Arrays.asList(new CoordinatorCheckpointEvent(1), new CoordinatorCheckpointEvent(2)),
+                listeners);
     }
 
     private SharedProgressAligner initializeAligner(
             IterationID iterationId,
             List<OperatorID> operatorIds,
             List<Integer> parallelisms,
-            List<? extends Consumer<GloballyAlignedEvent>> consumers) {
+            List<? extends SharedProgressAlignerListener> listeners) {
 
         SharedProgressAligner aligner =
                 SharedProgressAligner.getOrCreate(
@@ -181,28 +213,43 @@ public class SharedProgressAlignerTest extends TestLogger {
                         new MockOperatorCoordinatorContext(operatorIds.get(0), parallelisms.get(0)),
                         DirectScheduledExecutorService::new);
 
-        for (int i = 0; i < consumers.size(); ++i) {
-            aligner.registerAlignedConsumer(operatorIds.get(i), consumers.get(i));
+        for (int i = 0; i < listeners.size(); ++i) {
+            aligner.registerAlignedListener(operatorIds.get(i), listeners.get(i));
         }
 
         return aligner;
     }
 
-    private void checkRecordingConsumers(
+    private void checkGloballyAlignedEvents(
             List<GloballyAlignedEvent> expectedGloballyAlignedEvents,
-            List<RecordingConsumer> consumers) {
-        for (RecordingConsumer consumer : consumers) {
+            List<RecordingListener> listeners) {
+        for (RecordingListener consumer : listeners) {
             assertEquals(expectedGloballyAlignedEvents, consumer.globallyAlignedEvents);
         }
     }
 
-    private static class RecordingConsumer implements Consumer<GloballyAlignedEvent> {
+    private void checkCoordinatorCheckpointEvents(
+            List<CoordinatorCheckpointEvent> expectedGloballyAlignedEvents,
+            List<RecordingListener> listeners) {
+        for (RecordingListener consumer : listeners) {
+            assertEquals(expectedGloballyAlignedEvents, consumer.checkpointEvents);
+        }
+    }
+
+    private static class RecordingListener implements SharedProgressAlignerListener {
 
         final List<GloballyAlignedEvent> globallyAlignedEvents = new ArrayList<>();
 
+        final List<CoordinatorCheckpointEvent> checkpointEvents = new ArrayList<>();
+
         @Override
-        public void accept(GloballyAlignedEvent globallyAlignedEvent) {
+        public void onAligned(GloballyAlignedEvent globallyAlignedEvent) {
             globallyAlignedEvents.add(globallyAlignedEvent);
         }
+
+        @Override
+        public void onCheckpointAligned(CoordinatorCheckpointEvent coordinatorCheckpointEvent) {
+            checkpointEvents.add(coordinatorCheckpointEvent);
+        }
     }
 }

[flink-ml] 02/08: [hotfix][iteration] Simplify the head operator test

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 9a15fb43b9e3aaa9322fb9e5f37b181f1813fda7
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Wed Oct 6 18:36:29 2021 +0800

    [hotfix][iteration] Simplify the head operator test
---
 .../flink/iteration/operator/HeadOperatorTest.java | 281 ++++++++++++---------
 1 file changed, 159 insertions(+), 122 deletions(-)

diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index f54422e..3ded30a 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -25,9 +25,9 @@ import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
-import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
@@ -38,11 +38,16 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
+import org.apache.flink.util.function.FunctionWithException;
 
 import org.junit.Test;
 
+import javax.annotation.Nullable;
+
+import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -62,144 +67,173 @@ public class HeadOperatorTest extends TestLogger {
     @Test
     public void testForwardRecords() throws Exception {
         IterationID iterationId = new IterationID();
-        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
-                new StreamTaskMailboxTestHarnessBuilder<>(
-                                OneInputStreamTask::new,
-                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .setupOutputForSingletonOperatorChain(
-                                new RecordingHeadOperatorFactory(
-                                        iterationId, 0, false, 5, MockOperatorEventGateway::new))
-                        .build()) {
-            harness.processElement(new StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
-            putFeedbackRecords(
-                    iterationId, 0, new StreamRecord<>(IterationRecord.newRecord(3, 1), 3));
-            harness.processAll();
-            harness.processElement(new StreamRecord<>(IterationRecord.newRecord(2, 0), 3));
-            putFeedbackRecords(
-                    iterationId, 0, new StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
-            harness.processAll();
-
-            List<StreamRecord<IterationRecord<Integer>>> expectedOutput =
-                    Arrays.asList(
-                            new StreamRecord<>(IterationRecord.newRecord(1, 0), 2),
-                            new StreamRecord<>(IterationRecord.newRecord(3, 1), 3),
-                            new StreamRecord<>(IterationRecord.newRecord(2, 0), 3),
-                            new StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
-            assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
-
-            RegularHeadOperatorRecordProcessor recordProcessor =
-                    (RegularHeadOperatorRecordProcessor)
-                            RecordingHeadOperatorFactory.latestHeadOperator.getRecordProcessor();
-
-            assertEquals(2, (long) recordProcessor.getNumFeedbackRecordsPerEpoch().get(1));
-        }
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.processElement(new StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
+                    putFeedbackRecords(iterationId, IterationRecord.newRecord(3, 1), 3L);
+                    harness.processAll();
+                    harness.processElement(new StreamRecord<>(IterationRecord.newRecord(2, 0), 3));
+                    putFeedbackRecords(iterationId, IterationRecord.newRecord(4, 1), 4L);
+                    harness.processAll();
+
+                    List<StreamRecord<IterationRecord<Integer>>> expectedOutput =
+                            Arrays.asList(
+                                    new StreamRecord<>(IterationRecord.newRecord(1, 0), 2),
+                                    new StreamRecord<>(IterationRecord.newRecord(3, 1), 3),
+                                    new StreamRecord<>(IterationRecord.newRecord(2, 0), 3),
+                                    new StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
+                    assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
+
+                    RegularHeadOperatorRecordProcessor recordProcessor =
+                            (RegularHeadOperatorRecordProcessor)
+                                    RecordingHeadOperatorFactory.latestHeadOperator
+                                            .getRecordProcessor();
+
+                    assertEquals(2, (long) recordProcessor.getNumFeedbackRecordsPerEpoch().get(1));
+
+                    return null;
+                });
     }
 
     @Test(timeout = 60000)
     public void testSynchronizingEpochWatermark() throws Exception {
         IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.processElement(new StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
+
+                    // We will start a new thread to simulate the operator coordinator thread
+                    CompletableFuture<Void> taskExecuteResult =
+                            CompletableFuture.supplyAsync(
+                                    () -> {
+                                        try {
+                                            RecordingOperatorEventGateway eventGateway =
+                                                    (RecordingOperatorEventGateway)
+                                                            RecordingHeadOperatorFactory
+                                                                    .latestHeadOperator
+                                                                    .getOperatorEventGateway();
+
+                                            // We should get the aligned event for round 0 on
+                                            // endInput
+                                            assertNextOperatorEvent(
+                                                    new SubtaskAlignedEvent(0, 0, false),
+                                                    eventGateway);
+                                            dispatchOperatorEvent(
+                                                    harness,
+                                                    operatorId,
+                                                    new GloballyAlignedEvent(0, false));
+
+                                            putFeedbackRecords(
+                                                    iterationId,
+                                                    IterationRecord.newRecord(4, 1),
+                                                    4L);
+                                            putFeedbackRecords(
+                                                    iterationId,
+                                                    IterationRecord.newEpochWatermark(1, "tail"),
+                                                    0L);
+
+                                            assertNextOperatorEvent(
+                                                    new SubtaskAlignedEvent(1, 1, false),
+                                                    eventGateway);
+                                            dispatchOperatorEvent(
+                                                    harness,
+                                                    operatorId,
+                                                    new GloballyAlignedEvent(1, true));
+
+                                            while (RecordingHeadOperatorFactory.latestHeadOperator
+                                                            .getStatus()
+                                                    == HeadOperator.HeadOperatorStatus.RUNNING) ;
+                                            putFeedbackRecords(
+                                                    iterationId,
+                                                    IterationRecord.newEpochWatermark(
+                                                            Integer.MAX_VALUE + 1, "tail"),
+                                                    null);
+
+                                            return null;
+                                        } catch (Throwable e) {
+                                            RecordingHeadOperatorFactory.latestHeadOperator
+                                                    .getMailboxExecutor()
+                                                    .execute(
+                                                            () -> {
+                                                                throw e;
+                                                            },
+                                                            "poison mail");
+                                            throw new CompletionException(e);
+                                        }
+                                    });
+
+                    // Mark the input as finished.
+                    harness.processEvent(EndOfData.INSTANCE);
+
+                    // There should be no exception
+                    taskExecuteResult.get();
+
+                    assertEquals(
+                            Arrays.asList(
+                                    new StreamRecord<>(IterationRecord.newRecord(1, 0), 2),
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    0,
+                                                    OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0),
+                                    new StreamRecord<>(IterationRecord.newRecord(4, 1), 4),
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(
+                                                    Integer.MAX_VALUE,
+                                                    OperatorUtils.getUniqueSenderId(operatorId, 0)),
+                                            0)),
+                            new ArrayList<>(harness.getOutput()));
+                    return null;
+                });
+    }
+
+    private <T> T createHarnessAndRun(
+            IterationID iterationId,
+            OperatorID operatorId,
+            @Nullable TaskStateSnapshot snapshot,
+            FunctionWithException<
+                            StreamTaskMailboxTestHarness<IterationRecord<Integer>>, T, Exception>
+                    runnable)
+            throws Exception {
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
                 new StreamTaskMailboxTestHarnessBuilder<>(
                                 OneInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
                         .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setTaskStateSnapshot(
+                                1, snapshot == null ? new TaskStateSnapshot() : snapshot)
                         .setupOutputForSingletonOperatorChain(
                                 new RecordingHeadOperatorFactory(
                                         iterationId,
                                         0,
                                         false,
                                         5,
-                                        RecordingOperatorEventGateway::new))
+                                        RecordingOperatorEventGateway::new),
+                                operatorId)
                         .build()) {
-
-            OperatorID operatorId = RecordingHeadOperatorFactory.latestHeadOperator.getOperatorID();
-            harness.processElement(new StreamRecord<>(IterationRecord.newRecord(1, 0), 2));
-
-            // We will start a new thread to simulate the operator coordinator thread
-            CompletableFuture<Void> taskExecuteResult =
-                    CompletableFuture.supplyAsync(
-                            () -> {
-                                try {
-                                    RecordingOperatorEventGateway eventGateway =
-                                            (RecordingOperatorEventGateway)
-                                                    RecordingHeadOperatorFactory.latestHeadOperator
-                                                            .getOperatorEventGateway();
-
-                                    // We should get the aligned event for round 0 on endInput
-                                    assertNextOperatorEvent(
-                                            new SubtaskAlignedEvent(0, 0, false), eventGateway);
-                                    harness.getStreamTask()
-                                            .dispatchOperatorEvent(
-                                                    operatorId,
-                                                    new SerializedValue<>(
-                                                            new GloballyAlignedEvent(0, false)));
-
-                                    putFeedbackRecords(
-                                            iterationId,
-                                            0,
-                                            new StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
-                                    putFeedbackRecords(
-                                            iterationId,
-                                            0,
-                                            new StreamRecord<>(
-                                                    IterationRecord.newEpochWatermark(1, "tail"),
-                                                    0));
-
-                                    assertNextOperatorEvent(
-                                            new SubtaskAlignedEvent(1, 1, false), eventGateway);
-                                    harness.getStreamTask()
-                                            .dispatchOperatorEvent(
-                                                    operatorId,
-                                                    new SerializedValue<>(
-                                                            new GloballyAlignedEvent(1, true)));
-
-                                    while (RecordingHeadOperatorFactory.latestHeadOperator
-                                                    .getStatus()
-                                            == HeadOperator.HeadOperatorStatus.RUNNING) {}
-                                    putFeedbackRecords(
-                                            iterationId,
-                                            0,
-                                            new StreamRecord<>(
-                                                    IterationRecord.newEpochWatermark(
-                                                            Integer.MAX_VALUE + 1, "tail")));
-
-                                    return null;
-                                } catch (Throwable e) {
-                                    RecordingHeadOperatorFactory.latestHeadOperator
-                                            .getMailboxExecutor()
-                                            .execute(
-                                                    () -> {
-                                                        throw e;
-                                                    },
-                                                    "poison mail");
-                                    throw new CompletionException(e);
-                                }
-                            });
-
-            // Mark the input as finished.
-            harness.processEvent(EndOfData.INSTANCE);
-
-            // There should be no exception
-            taskExecuteResult.get();
-
-            assertEquals(
-                    Arrays.asList(
-                            new StreamRecord<>(IterationRecord.newRecord(1, 0), 2),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            0, OperatorUtils.getUniqueSenderId(operatorId, 0)),
-                                    0),
-                            new StreamRecord<>(IterationRecord.newRecord(4, 1), 4),
-                            new StreamRecord<>(
-                                    IterationRecord.newEpochWatermark(
-                                            Integer.MAX_VALUE,
-                                            OperatorUtils.getUniqueSenderId(operatorId, 0)),
-                                    0)),
-                    new ArrayList<>(harness.getOutput()));
+            return runnable.apply(harness);
         }
     }
 
+    private static void dispatchOperatorEvent(
+            StreamTaskMailboxTestHarness<?> harness,
+            OperatorID operatorId,
+            OperatorEvent operatorEvent)
+            throws IOException, FlinkException {
+        harness.getStreamTask()
+                .dispatchOperatorEvent(operatorId, new SerializedValue<>(operatorEvent));
+    }
+
     private static void assertNextOperatorEvent(
             OperatorEvent expectedEvent, RecordingOperatorEventGateway eventGateway)
             throws InterruptedException {
@@ -209,14 +243,17 @@ public class HeadOperatorTest extends TestLogger {
     }
 
     private static void putFeedbackRecords(
-            IterationID iterationId, int feedbackIndex, StreamRecord<IterationRecord<?>> record) {
+            IterationID iterationId, IterationRecord<?> record, @Nullable Long timestamp) {
         FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
                 FeedbackChannelBroker.get()
                         .getChannel(
                                 OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
-                                                iterationId, feedbackIndex)
+                                                iterationId, 0)
                                         .withSubTaskIndex(0, 0));
-        feedbackChannel.put(record);
+        feedbackChannel.put(
+                timestamp == null
+                        ? new StreamRecord<>(record)
+                        : new StreamRecord<>(record, timestamp));
     }
 
     private static class RecordingOperatorEventGateway implements OperatorEventGateway {

[flink-ml] 07/08: [FLINK-24655][iteration] Do not rely on the precedent tasks to insert epoch watermark

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 9172ec762c797a697e71c17b401a94c3222dc506
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 10:47:30 2021 +0800

    [FLINK-24655][iteration] Do not rely on the precedent tasks to insert epoch watermark
    
    If a task had already finished before a failover, it would be skipped on
    restore and would not insert the epoch watermark again.
---
 .../org/apache/flink/iteration/Iterations.java     | 29 +++++++++---------
 .../operator/AbstractWrapperOperator.java          | 10 ++++++-
 .../flink/iteration/operator/InputOperator.java    | 20 ++-----------
 .../flink/iteration/operator/ReplayOperator.java   | 28 +++++++++++++++---
 .../allround/AbstractAllRoundWrapperOperator.java  | 21 +++++++------
 .../MultipleInputAllRoundWrapperOperator.java      |  2 ++
 .../allround/TwoInputAllRoundWrapperOperator.java  |  2 ++
 .../OperatorEpochWatermarkTracker.java             | 24 +++++++++++++--
 .../flink/iteration/IterationConstructionTest.java |  2 +-
 .../iteration/operator/InputOperatorTest.java      | 34 +---------------------
 .../iteration/operator/ReplayOperatorTest.java     | 21 +++++++------
 .../OperatorEpochWatermarkTrackerTest.java         | 17 +++++++++++
 12 files changed, 113 insertions(+), 97 deletions(-)
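
The central change is that wrapper operators now mark an input as finished via
endInput() instead of expecting the precedent task to emit an
Integer.MAX_VALUE epoch watermark. A minimal sketch of the resulting tracker
contract, assuming a tracker over two inputs with one channel each (listener
stands for any OperatorEpochWatermarkTrackerListener):

    OperatorEpochWatermarkTracker tracker =
            new OperatorEpochWatermarkTracker(new int[] {1, 1}, listener);

    tracker.onEpochWatermark(1, "tail-0", 0); // input 1 advances to epoch 0
    tracker.finish(0); // input 0 finished: its channels jump to Integer.MAX_VALUE

    // The overall lower bound becomes min(Integer.MAX_VALUE, 0) = 0, so
    // listener.onEpochWatermarkIncrement(0) fires even though input 0 never
    // emitted an epoch watermark; a task restored as already-finished after a
    // failover therefore no longer blocks the alignment.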

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
index b334da2..9c9fa42 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
@@ -35,7 +35,7 @@ import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
 import org.apache.flink.util.Collector;
 
@@ -197,7 +197,7 @@ public class Iterations {
                         .stream()
                         .mapToInt(i -> i)
                         .sum();
-        DataStreamList initVariableInputs = addInputs(initVariableStreams, false);
+        DataStreamList initVariableInputs = addInputs(initVariableStreams);
         DataStreamList headStreams =
                 addHeads(
                         initVariableStreams,
@@ -207,7 +207,7 @@ public class Iterations {
                         false,
                         0);
 
-        DataStreamList dataStreamInputs = addInputs(dataStreams, true);
+        DataStreamList dataStreamInputs = addInputs(dataStreams);
         if (replayedDataStreamIndices.size() > 0) {
             dataStreamInputs =
                     addReplayer(
@@ -293,13 +293,13 @@ public class Iterations {
             // Note that the HeadOperator would broadcast the globally aligned events,
             // thus the operator does not need to specially emit to the side output.
             DataStream<?> replayedInput =
-                    ((SingleOutputStreamOperator<IterationRecord<?>>) firstHeadStream)
-                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
-                            .map(x -> x, dataStreamInputs.get(i).getType())
-                            .setParallelism(firstHeadStream.getParallelism())
-                            .name("signal-change-typeinfo")
-                            .broadcast()
-                            .union(dataStreamInputs.get(i))
+                    dataStreamInputs
+                            .get(i)
+                            .connect(
+                                    ((SingleOutputStreamOperator<IterationRecord<?>>)
+                                                    firstHeadStream)
+                                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
+                                            .broadcast())
                             .transform(
                                     "Replayer-"
                                             + originalDataStreams
@@ -307,7 +307,7 @@ public class Iterations {
                                                     .getTransformation()
                                                     .getName(),
                                     dataStreamInputs.get(i).getType(),
-                                    (OneInputStreamOperator) new ReplayOperator<>())
+                                    (TwoInputStreamOperator) new ReplayOperator<>())
                             .setParallelism(dataStreamInputs.get(i).getParallelism());
             result.add(replayedInput);
         }
@@ -338,7 +338,7 @@ public class Iterations {
                         .name(terminationCriteria.getTransformation().getName())
                         .setParallelism(terminationCriteria.getParallelism());
         DataStreamList criteriaSources = DataStreamList.of(emptyCriteriaSource);
-        DataStreamList criteriaInputs = addInputs(criteriaSources, false);
+        DataStreamList criteriaInputs = addInputs(criteriaSources);
         DataStreamList criteriaHeaders =
                 addHeads(
                         criteriaSources,
@@ -398,8 +398,7 @@ public class Iterations {
         return map(dataStreams, DataStream::getType);
     }
 
-    private static DataStreamList addInputs(
-            DataStreamList dataStreams, boolean insertMaxEpochWatermark) {
+    private static DataStreamList addInputs(DataStreamList dataStreams) {
         return new DataStreamList(
                 map(
                         dataStreams,
@@ -408,7 +407,7 @@ public class Iterations {
                                         .transform(
                                                 "input-" + dataStream.getTransformation().getName(),
                                                 new IterationRecordTypeInfo<>(dataStream.getType()),
-                                                new InputOperator(insertMaxEpochWatermark))
+                                                new InputOperator())
                                         .setParallelism(dataStream.getParallelism())));
     }
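
With these changes, the replayed input receives the alignment notifications as
a second, broadcast input of a two-input operator rather than as a re-typed
stream unioned into the single input. A minimal sketch of the new wiring
(dataInput and headStream are illustrative placeholders for the streams built
above):

    DataStream<IterationRecord<?>> alignNotify =
            ((SingleOutputStreamOperator<IterationRecord<?>>) headStream)
                    .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG);

    DataStream<?> replayed =
            dataInput
                    .connect(alignNotify.broadcast())
                    .transform(
                            "Replayer-Constant",
                            dataInput.getType(),
                            (TwoInputStreamOperator) new ReplayOperator<>())
                    .setParallelism(dataInput.getParallelism());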
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
index 3aebce9..80a6682 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.metrics.groups.InternalOperatorMetricGroup;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
@@ -47,7 +48,9 @@ import static org.apache.flink.util.Preconditions.checkState;
 
 /** The base class of all the wrapper operators. It provides the alignment functionality. */
 public abstract class AbstractWrapperOperator<T>
-        implements StreamOperator<IterationRecord<T>>, OperatorEpochWatermarkTrackerListener {
+        implements StreamOperator<IterationRecord<T>>,
+                OperatorEpochWatermarkTrackerListener,
+                BoundedMultiInput {
 
     private static final Logger LOG = LoggerFactory.getLogger(AbstractWrapperOperator.class);
 
@@ -130,6 +133,11 @@ public abstract class AbstractWrapperOperator<T>
         epochWatermarkSupplier.set(null);
     }
 
+    @Override
+    public void endInput(int i) throws Exception {
+        epochWatermarkTracker.finish(i - 1);
+    }
+
     private InternalOperatorMetricGroup createOperatorMetricGroup(
             Environment environment, StreamConfig streamConfig) {
         try {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java
index 5ef167e..b6908ff 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java
@@ -20,21 +20,17 @@ package org.apache.flink.iteration.operator;
 
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 /** Input operator that wraps the user record into {@link IterationRecord}. */
 public class InputOperator<T> extends AbstractStreamOperator<IterationRecord<T>>
-        implements OneInputStreamOperator<T, IterationRecord<T>>, BoundedOneInput {
-
-    private final boolean insertMaxEpochWatermark;
+        implements OneInputStreamOperator<T, IterationRecord<T>> {
 
     private transient StreamRecord<IterationRecord<T>> reusable;
 
-    public InputOperator(boolean insertMaxEpochWatermark) {
-        this.insertMaxEpochWatermark = insertMaxEpochWatermark;
+    public InputOperator() {
         this.chainingStrategy = ChainingStrategy.ALWAYS;
     }
 
@@ -50,16 +46,4 @@ public class InputOperator<T> extends AbstractStreamOperator<IterationRecord<T>>
         reusable.getValue().setValue(streamRecord.getValue());
         output.collect(reusable);
     }
-
-    @Override
-    public void endInput() throws Exception {
-        if (insertMaxEpochWatermark) {
-            reusable.replace(
-                    IterationRecord.newEpochWatermark(
-                            Integer.MAX_VALUE,
-                            OperatorUtils.getUniqueSenderId(
-                                    getOperatorID(), getRuntimeContext().getIndexOfThisSubtask())));
-            output.collect(reusable);
-        }
-    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
index 2deb01e..16c6653 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
@@ -31,8 +31,9 @@ import org.apache.flink.iteration.typeinfo.IterationRecordSerializer;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.ExceptionUtils;
@@ -46,8 +47,10 @@ import static org.apache.flink.util.Preconditions.checkState;
 
 /** Replays the data received in round 0 in the following rounds. */
 public class ReplayOperator<T> extends AbstractStreamOperator<IterationRecord<T>>
-        implements OneInputStreamOperator<IterationRecord<T>, IterationRecord<T>>,
-                OperatorEpochWatermarkTrackerListener {
+        implements TwoInputStreamOperator<
+                        IterationRecord<T>, IterationRecord<Void>, IterationRecord<T>>,
+                OperatorEpochWatermarkTrackerListener,
+                BoundedMultiInput {
 
     private OperatorEpochWatermarkTracker progressTracker;
 
@@ -115,7 +118,7 @@ public class ReplayOperator<T> extends AbstractStreamOperator<IterationRecord<T>
     }
 
     @Override
-    public void processElement(StreamRecord<IterationRecord<T>> element) throws Exception {
+    public void processElement1(StreamRecord<IterationRecord<T>> element) throws Exception {
         switch (element.getValue().getType()) {
             case RECORD:
                 dataCacheWriter.addRecord(element.getValue().getValue());
@@ -132,6 +135,23 @@ public class ReplayOperator<T> extends AbstractStreamOperator<IterationRecord<T>
     }
 
     @Override
+    public void processElement2(StreamRecord<IterationRecord<Void>> element) throws Exception {
+        if (element.getValue().getType() == IterationRecord.Type.EPOCH_WATERMARK) {
+            progressTracker.onEpochWatermark(
+                    1, element.getValue().getSender(), element.getValue().getEpoch());
+        } else {
+            throw new UnsupportedOperationException(
+                    "Not supported element type: " + element.getValue());
+        }
+    }
+
+    @Override
+    public void endInput(int i) throws Exception {
+        // The notification ranges from 1 to N while the tracker uses 0 to N - 1.
+        progressTracker.finish(i - 1);
+    }
+
+    @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
         if (epochWatermark == 0) {
             // No need to replay for the round 0, it is output directly.
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index d3461a1..0ea742c 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -99,19 +99,18 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
-        if (epochWatermark <= latestEpochWatermark) {
-            return;
+        if (epochWatermark > latestEpochWatermark) {
+            latestEpochWatermark = epochWatermark;
+
+            setIterationContextRound(epochWatermark);
+            processOperatorOrUdfIfSatisfy(
+                    wrappedOperator,
+                    IterationListener.class,
+                    listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
+            clearIterationContextRound();
         }
-        latestEpochWatermark = epochWatermark;
 
-        setIterationContextRound(epochWatermark);
-        processOperatorOrUdfIfSatisfy(
-                wrappedOperator,
-                IterationListener.class,
-                listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
-        clearIterationContextRound();
-
-        // Broadcast the events.
+        // Always broadcasts the events.
         super.onEpochWatermarkIncrement(epochWatermark);
     }
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
index c975538..4d94298 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
@@ -80,6 +80,8 @@ public class MultipleInputAllRoundWrapperOperator<OUT>
 
     @Override
     public void endInput(int i) throws Exception {
+        super.endInput(i);
+
         if (wrappedOperator instanceof BoundedMultiInput) {
             ((BoundedMultiInput) wrappedOperator).endInput(i);
         }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
index 2c633c1..bedcccf 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
@@ -113,6 +113,8 @@ public class TwoInputAllRoundWrapperOperator<IN1, IN2, OUT>
 
     @Override
     public void endInput(int i) throws Exception {
+        super.endInput(i);
+
         if (wrappedOperator instanceof BoundedMultiInput) {
             ((BoundedMultiInput) wrappedOperator).endInput(i);
         }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
index e2d716b..33ddee8 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
@@ -60,9 +60,21 @@ public class OperatorEpochWatermarkTracker {
         InputStatus inputStatus = inputStatuses.get(inputIndex);
         inputStatus.onUpdate(sender, epochWatermark);
 
-        if (inputStatus.getInputLowerBound() > allInputsLowerBound.getValue(inputIndex)) {
+        tryUpdateLowerBound(inputIndex);
+    }
+
+    public void finish(int inputIndex) throws IOException {
+        inputStatuses.get(inputIndex).finish();
+
+        tryUpdateLowerBound(inputIndex);
+    }
+
+    private void tryUpdateLowerBound(int changedInputIndex) throws IOException {
+        if (inputStatuses.get(changedInputIndex).getInputLowerBound()
+                > allInputsLowerBound.getValue(changedInputIndex)) {
             int oldLowerBound = allInputsLowerBound.getLowerBound();
-            allInputsLowerBound.updateValue(inputIndex, inputStatus.getInputLowerBound());
+            allInputsLowerBound.updateValue(
+                    changedInputIndex, inputStatuses.get(changedInputIndex).getInputLowerBound());
             if (allInputsLowerBound.getLowerBound() > oldLowerBound) {
                 progressTrackerListener.onEpochWatermarkIncrement(
                         allInputsLowerBound.getLowerBound());
@@ -95,6 +107,12 @@ public class OperatorEpochWatermarkTracker {
             allChannelsLowerBound.updateValue(index, epochWatermark);
         }
 
+        public void finish() {
+            for (int i = 0; i < numberOfChannels; ++i) {
+                allChannelsLowerBound.updateValue(i, Integer.MAX_VALUE);
+            }
+        }
+
         public int getInputLowerBound() {
             return allChannelsLowerBound.getLowerBound();
         }
@@ -122,7 +140,7 @@ public class OperatorEpochWatermarkTracker {
 
         public void updateValue(int channel, int value) {
             checkState(
-                    value > values[channel],
+                    value >= values[channel],
                     String.format(
                             "The channel %d received an outdated value %d, which currently is %d",
                             channel, value, values[channel]));
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index 3cd4bf2..f2ec465 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -378,7 +378,7 @@ public class IterationConstructionTest extends TestLogger {
                         /* 0 */ "Source: Variable -> input-Variable",
                         /* 1 */ "Source: Constant -> input-Constant",
                         /* 2 */ "Source: Termination -> input-Termination",
-                        /* 3 */ "head-Variable -> signal-change-typeinfo",
+                        /* 3 */ "head-Variable",
                         /* 4 */ "Replayer-Constant",
                         /* 5 */ "Processor -> output-SideOutput -> Sink: Sink",
                         /* 6 */ "Feedback",
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java
index 03cba7a..bba3fec 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java
@@ -34,7 +34,7 @@ public class InputOperatorTest extends TestLogger {
     @Test
     public void testWrapRecord() throws Exception {
         OneInputStreamOperatorTestHarness<Integer, IterationRecord<Integer>> testHarness =
-                new OneInputStreamOperatorTestHarness<>(new InputOperator<>(false));
+                new OneInputStreamOperatorTestHarness<>(new InputOperator<>());
         testHarness.open();
 
         ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
@@ -46,36 +46,4 @@ public class InputOperatorTest extends TestLogger {
         TestHarnessUtil.assertOutputEquals(
                 "Output was not correct", expectedOutput, testHarness.getOutput());
     }
-
-    @Test
-    public void testInsertMaxEpochWatermarkIfSpecified() throws Exception {
-        OneInputStreamOperatorTestHarness<Integer, IterationRecord<Integer>> testHarness =
-                new OneInputStreamOperatorTestHarness<>(new InputOperator<>(true));
-        testHarness.open();
-
-        testHarness.endInput();
-
-        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
-        expectedOutput.add(
-                new StreamRecord<>(
-                        IterationRecord.newEpochWatermark(
-                                Integer.MAX_VALUE,
-                                OperatorUtils.getUniqueSenderId(
-                                        testHarness.getOperator().getOperatorID(), 0))));
-
-        TestHarnessUtil.assertOutputEquals(
-                "Output was not correct", expectedOutput, testHarness.getOutput());
-    }
-
-    @Test
-    public void testNotInsertMaxEpochWatermarkIfSpecified() throws Exception {
-        OneInputStreamOperatorTestHarness<Integer, IterationRecord<Integer>> testHarness =
-                new OneInputStreamOperatorTestHarness<>(new InputOperator<>(false));
-        testHarness.open();
-
-        testHarness.endInput();
-
-        TestHarnessUtil.assertOutputEquals(
-                "Output was not correct", new ConcurrentLinkedQueue<>(), testHarness.getOutput());
-    }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
index d17c78f..97d61ca 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
@@ -24,9 +24,9 @@ import org.apache.flink.iteration.config.IterationOptions;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
 
 import org.junit.Rule;
 import org.junit.Test;
@@ -52,9 +52,10 @@ public class ReplayOperatorTest {
 
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
                 new StreamTaskMailboxTestHarnessBuilder<>(
-                                OneInputStreamTask::new,
+                                TwoInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO), 2)
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO), 1)
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.VOID_TYPE_INFO), 1)
                         .setupOutputForSingletonOperatorChain(new ReplayOperator<>(), operatorId)
                         .buildUnrestored()) {
             harness.getStreamTask()
@@ -70,19 +71,16 @@ public class ReplayOperatorTest {
             for (int i = 0; i < numRecords; ++i) {
                 harness.processElement(new StreamRecord<>(IterationRecord.newRecord(i, 0)), 0, 0);
             }
+            harness.endInput(0, true);
             harness.processElement(
-                    new StreamRecord<>(
-                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "sender0")),
-                    0,
-                    0);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, "sender1")), 0, 1);
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, "sender1")), 1, 0);
             assertOutputAllRecordsAndEpochWatermark(harness.getOutput(), numRecords, operatorId, 0);
             harness.getOutput().clear();
 
             // The round 1
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, "sender1")), 0, 1);
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, "sender1")), 1, 0);
+            // The output would be done asynchronously inside the ReplayOperator.
             while (harness.getOutput().size() < numRecords + 1) {
                 Thread.sleep(500);
             }
@@ -91,7 +89,8 @@ public class ReplayOperatorTest {
 
             // The round 2
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "sender1")), 0, 1);
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "sender1")), 1, 0);
+            // The output would be done asynchronously inside the ReplayOperator.
             while (harness.getOutput().size() < numRecords + 1) {
                 Thread.sleep(500);
             }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java
index 73a95b1..4935484 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java
@@ -59,6 +59,23 @@ public class OperatorEpochWatermarkTrackerTest extends TestLogger {
         assertEquals(Collections.singletonList(3), recordingProgressListener.notifications);
     }
 
+    @Test
+    public void testFinish() throws IOException {
+        RecordingProgressListener recordingProgressListener = new RecordingProgressListener();
+        int[] numberOfChannels = new int[] {2, 3};
+        OperatorEpochWatermarkTracker progressTracker =
+                new OperatorEpochWatermarkTracker(numberOfChannels, recordingProgressListener);
+        progressTracker.finish(0);
+
+        testOnEpochWatermark(
+                new int[] {0, 0, 1},
+                progressTracker,
+                recordingProgressListener,
+                new int[] {1, 1, 1},
+                new String[] {"1-1", "1-2", "1-3"},
+                3);
+    }
+
     private void testOnEpochWatermark(
             int[] expectedNumNotifications,
             OperatorEpochWatermarkTracker tracker,

[flink-ml] 05/08: [FLINK-24655][iteration] Support the checkpoints for the iteration

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 4b47e5df79b35336c4271f9dce1cbac9781451ef
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 00:39:49 2021 +0800

    [FLINK-24655][iteration] Support the checkpoints for the iteration
---
 .../flink/iteration/checkpoint/Checkpoints.java    |  77 +++-
 .../iteration/checkpoint/CheckpointsBroker.java    |  57 +++
 .../flink/iteration/operator/HeadOperator.java     | 249 ++++++++++-
 .../operator/HeadOperatorCheckpointAligner.java    |  26 ++
 .../iteration/operator/HeadOperatorFactory.java    |   8 +
 ...dOperatorState.java => OperatorStateUtils.java} |  26 +-
 .../flink/iteration/operator/TailOperator.java     |  27 ++
 .../coordinator/HeadOperatorCoordinator.java       |  10 +-
 .../coordinator/SharedProgressAligner.java         |  28 ++
 .../headprocessor/HeadOperatorRecordProcessor.java |   7 +-
 .../operator/headprocessor/HeadOperatorState.java  |  32 +-
 .../RegularHeadOperatorRecordProcessor.java        | 110 ++++-
 .../TerminatingHeadOperatorRecordProcessor.java    |   7 +-
 .../flink/iteration/IterationConstructionTest.java |  38 ++
 .../flink/iteration/operator/HeadOperatorTest.java | 496 ++++++++++++++++++++-
 15 files changed, 1154 insertions(+), 44 deletions(-)
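
At the core of this commit is the extended Checkpoints class, which logs the
feedback records that are in flight while a checkpoint is pending. A minimal
sketch of its lifecycle, using only methods from this diff (checkpoints,
checkpointId, outputStream and feedbackRecord are illustrative variables; note
that abort() may race with startLogging(), which the aborted flag introduced
below resolves):

    // Begin logging the feedback records for this checkpoint.
    checkpoints.startLogging(checkpointId, outputStream);

    // Append every feedback record observed while the checkpoint is pending.
    checkpoints.append(feedbackRecord);

    // On completion: snapshot the logged records into the raw state and clean up.
    checkpoints.commitCheckpointsUntil(checkpointId);

    // On abortion: mark the checkpoint aborted, so that a late startLogging()
    // for the same id becomes a no-op.
    checkpoints.abort(checkpointId);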

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
index 03420f8..edbfeba 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
@@ -19,11 +19,13 @@
 package org.apache.flink.iteration.checkpoint;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
 import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
 import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.ResourceGuard;
 import org.apache.flink.util.function.SupplierWithException;
 
@@ -33,6 +35,9 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.util.SortedMap;
 import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** Maintains the pending checkpoints. */
 public class Checkpoints<T> implements AutoCloseable {
@@ -43,7 +48,10 @@ public class Checkpoints<T> implements AutoCloseable {
     private final FileSystem fileSystem;
     private final SupplierWithException<Path, IOException> pathSupplier;
 
-    private final TreeMap<Long, PendingCheckpoint> uncompletedCheckpoints = new TreeMap<>();
+    private final ConcurrentHashMap<Long, Tuple2<PendingCheckpoint, Boolean>>
+            uncompletedCheckpoints = new ConcurrentHashMap<>();
+
+    private final TreeMap<Long, PendingCheckpoint> sortedUncompletedCheckpoints = new TreeMap<>();
 
     public Checkpoints(
             TypeSerializer<T> typeSerializer,
@@ -51,27 +59,72 @@ public class Checkpoints<T> implements AutoCloseable {
             SupplierWithException<Path, IOException> pathSupplier) {
         this.typeSerializer = typeSerializer;
         this.fileSystem = fileSystem;
+        checkState(!fileSystem.isDistributedFS(), "Currently only local fs is supported");
         this.pathSupplier = pathSupplier;
     }
 
+    public TypeSerializer<T> getTypeSerializer() {
+        return typeSerializer;
+    }
+
+    public FileSystem getFileSystem() {
+        return fileSystem;
+    }
+
+    public SupplierWithException<Path, IOException> getPathSupplier() {
+        return pathSupplier;
+    }
+
     public void startLogging(long checkpointId, OperatorStateCheckpointOutputStream outputStream)
             throws IOException {
-        DataCacheWriter<T> dataCacheWriter =
-                new DataCacheWriter<>(typeSerializer, fileSystem, pathSupplier);
-        ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
-        uncompletedCheckpoints.put(
-                checkpointId, new PendingCheckpoint(dataCacheWriter, outputStream, snapshotLease));
+        Tuple2<PendingCheckpoint, Boolean> possibleCheckpoint =
+                uncompletedCheckpoints.computeIfAbsent(
+                        checkpointId,
+                        ignored -> {
+                            try {
+                                DataCacheWriter<T> dataCacheWriter =
+                                        new DataCacheWriter<>(
+                                                typeSerializer, fileSystem, pathSupplier);
+                                ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
+                                return new Tuple2<>(
+                                        new PendingCheckpoint(
+                                                dataCacheWriter, outputStream, snapshotLease),
+                                        false);
+                            } catch (IOException e) {
+                                throw new FlinkRuntimeException(e);
+                            }
+                        });
+
+        // If this checkpoint has already been aborted, skip logging.
+        if (possibleCheckpoint.f1) {
+            return;
+        }
+
+        sortedUncompletedCheckpoints.put(checkpointId, possibleCheckpoint.f0);
+    }
+
+    public void abort(long checkpointId) {
+        uncompletedCheckpoints.compute(
+                checkpointId,
+                (k, v) -> {
+                    if (v == null) {
+                        return new Tuple2<>(null, true);
+                    } else {
+                        v.f0.snapshotLease.close();
+                        return new Tuple2<>(v.f0, true);
+                    }
+                });
     }
 
     public void append(T element) throws IOException {
-        for (PendingCheckpoint pendingCheckpoint : uncompletedCheckpoints.values()) {
+        for (PendingCheckpoint pendingCheckpoint : sortedUncompletedCheckpoints.values()) {
             pendingCheckpoint.dataCacheWriter.addRecord(element);
         }
     }
 
     public void commitCheckpointsUntil(long checkpointId) {
         SortedMap<Long, PendingCheckpoint> completedCheckpoints =
-                uncompletedCheckpoints.headMap(checkpointId, true);
+                sortedUncompletedCheckpoints.headMap(checkpointId, true);
         completedCheckpoints
                 .values()
                 .forEach(
@@ -86,9 +139,13 @@ public class Checkpoints<T> implements AutoCloseable {
                                                         .getFinishSegments());
                                 pendingCheckpoint.checkpointOutputStream.startNewPartition();
                                 snapshot.writeTo(pendingCheckpoint.checkpointOutputStream);
+
+                                // Directly clean up all the files since we are using the local fs.
+                                // TODO: support remote file systems.
                                 pendingCheckpoint.dataCacheWriter.cleanup();
                             } catch (Exception e) {
                                 LOG.error("Failed to commit checkpoint until " + checkpointId, e);
+                                throw new FlinkRuntimeException(e);
                             } finally {
                                 pendingCheckpoint.snapshotLease.close();
                             }
@@ -99,7 +156,7 @@ public class Checkpoints<T> implements AutoCloseable {
 
     @Override
     public void close() {
-        uncompletedCheckpoints.forEach(
+        sortedUncompletedCheckpoints.forEach(
                 (checkpointId, pendingCheckpoint) -> {
                     pendingCheckpoint.snapshotLease.close();
                     try {
@@ -108,10 +165,12 @@ public class Checkpoints<T> implements AutoCloseable {
                         LOG.error("Failed to cleanup " + checkpointId, e);
                     }
                 });
+        sortedUncompletedCheckpoints.clear();
         uncompletedCheckpoints.clear();
     }
 
     private class PendingCheckpoint {
+
         final DataCacheWriter<T> dataCacheWriter;
 
         final OperatorStateCheckpointOutputStream checkpointOutputStream;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
new file mode 100644
index 0000000..e969c10
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.checkpoint;
+
+import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
+
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Hands off the {@link Checkpoints} from the head operator to the tail operator so that the tail
+ * operator can decrease the reference count of the raw state when checkpoints are aborted. We
+ * cannot rely on the head operator for this, since it would be blocked on closing the raw state
+ * when aborting the checkpoint; this also looks like a bug in Flink.
+ */
+public class CheckpointsBroker {
+
+    private static final CheckpointsBroker INSTANCE = new CheckpointsBroker();
+
+    private final ConcurrentHashMap<SubtaskFeedbackKey<?>, Checkpoints<?>> checkpointManagers =
+            new ConcurrentHashMap<>();
+
+    public static CheckpointsBroker get() {
+        return INSTANCE;
+    }
+
+    public <V> void setCheckpoints(SubtaskFeedbackKey<V> key, Checkpoints<V> checkpoints) {
+        checkpointManagers.put(key, checkpoints);
+    }
+
+    @SuppressWarnings({"unchecked"})
+    public <V> Checkpoints<V> getCheckpoints(SubtaskFeedbackKey<V> key) {
+        Objects.requireNonNull(key);
+        return (Checkpoints<V>) Objects.requireNonNull(checkpointManagers.get(key));
+    }
+
+    @SuppressWarnings("resource")
+    void removeChannel(SubtaskFeedbackKey<?> key) {
+        checkpointManagers.remove(key);
+    }
+}
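
The broker itself is just a JVM-wide hand-off table between two colocated operators.
A minimal sketch of the pattern (a generic key type stands in for SubtaskFeedbackKey;
the class and method names are illustrative):

    import java.util.Objects;
    import java.util.concurrent.ConcurrentHashMap;

    // JVM-wide hand-off table: one side registers a value under a per-subtask
    // key, the other side, running in the same slot, looks it up.
    public final class Broker<K, V> {

        private static final Broker<Object, Object> INSTANCE = new Broker<>();

        private final ConcurrentHashMap<K, V> entries = new ConcurrentHashMap<>();

        @SuppressWarnings("unchecked")
        public static <K, V> Broker<K, V> instance() {
            return (Broker<K, V>) INSTANCE;
        }

        public void set(K key, V value) {
            entries.put(key, value);
        }

        public V lookup(K key) {
            return Objects.requireNonNull(entries.get(Objects.requireNonNull(key)));
        }

        public void remove(K key) {
            entries.remove(key);
        }
    }

In the diff above, the head operator plays the set() side and the tail operator the
lookup() side, calling abort() on the shared Checkpoints instance it receives.
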
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index 6e5a7b4..5306d2b 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -21,23 +21,46 @@ package org.apache.flink.iteration.operator;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.broadcast.BroadcastOutput;
 import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
+import org.apache.flink.iteration.checkpoint.Checkpoints;
+import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
 import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
+import org.apache.flink.iteration.operator.headprocessor.HeadOperatorState;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.operator.headprocessor.TerminatingHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.iteration.utils.ReflectionUtils;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferConsumerWithPartialRecordLength;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartitionView;
+import org.apache.flink.runtime.io.network.partition.PrioritizedDeque;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
 import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
@@ -57,6 +80,9 @@ import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.OutputTag;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.Executor;
 
@@ -112,6 +138,16 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     private HeadOperatorCheckpointAligner checkpointAligner;
 
+    // ------------- states -------------------
+
+    private ListState<Integer> parallelismState;
+
+    private ListState<Integer> statusState;
+
+    private ListState<HeadOperatorState> processorState;
+
+    private Checkpoints<IterationRecord<?>> checkpoints;
+
     public HeadOperator(
             IterationID iterationId,
             int feedbackIndex,
@@ -146,12 +182,87 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     public void initializeState(StateInitializationContext context) throws Exception {
         super.initializeState(context);
 
+        parallelismState =
+                context.getOperatorStateStore()
+                        .getUnionListState(
+                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+                .ifPresent(
+                        oldParallelism ->
+                                checkState(
+                                        oldParallelism
+                                                == getRuntimeContext()
+                                                        .getNumberOfParallelSubtasks(),
+                                        "The head operator is recovered with parallelism changed from "
+                                                + oldParallelism
+                                                + " to "
+                                                + getRuntimeContext()
+                                                        .getNumberOfParallelSubtasks()));
+
+        // Initialize the status and the record processor.
         processorContext = new ContextImpl();
-        status = HeadOperatorStatus.RUNNING;
-        recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+        statusState =
+                context.getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("status", Integer.class));
+        status =
+                HeadOperatorStatus.values()[
+                        OperatorStateUtils.getUniqueElement(statusState, "status").orElse(0)];
+        if (status == HeadOperatorStatus.RUNNING) {
+            recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+        } else {
+            recordProcessor = new TerminatingHeadOperatorRecordProcessor();
+        }
+
+        // Recover the processor state if it exists.
+        processorState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "processorState", HeadOperatorState.class));
+        OperatorStateUtils.getUniqueElement(processorState, "processorState")
+                .ifPresent(
+                        headOperatorState ->
+                                recordProcessor.initializeState(
+                                        headOperatorState, context.getRawOperatorStateInputs()));
 
         checkpointAligner = new HeadOperatorCheckpointAligner();
 
+        // Initialize the checkpoints
+        Path dataCachePath =
+                OperatorUtils.getDataCachePath(
+                        getRuntimeContext().getTaskManagerRuntimeInfo().getConfiguration(),
+                        getContainingTask()
+                                .getEnvironment()
+                                .getIOManager()
+                                .getSpillingDirectoriesPaths());
+        this.checkpoints =
+                new Checkpoints<>(
+                        config.getTypeSerializerOut(getClass().getClassLoader()),
+                        dataCachePath.getFileSystem(),
+                        OperatorUtils.createDataCacheFileGenerator(
+                                dataCachePath, "header-cp", getOperatorConfig().getOperatorID()));
+        CheckpointsBroker.get()
+                .setCheckpoints(
+                        OperatorUtils.<IterationRecord<?>>createFeedbackKey(
+                                        iterationId, feedbackIndex)
+                                .withSubTaskIndex(
+                                        getRuntimeContext().getIndexOfThisSubtask(),
+                                        getRuntimeContext().getAttemptNumber()),
+                        checkpoints);
+
+        try {
+            for (StatePartitionStreamProvider rawStateInput : context.getRawOperatorStateInputs()) {
+                DataCacheSnapshot.replay(
+                        rawStateInput.getStream(),
+                        checkpoints.getTypeSerializer(),
+                        checkpoints.getFileSystem(),
+                        (record) ->
+                                recordProcessor.processFeedbackElement(new StreamRecord<>(record)));
+            }
+        } catch (Exception e) {
+            throw new FlinkRuntimeException("Failed to replay the records", e);
+        }
+
         // Here we register a mail
         registerFeedbackConsumer(
                 (Runnable runnable) -> {
@@ -171,18 +282,54 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     @Override
     public void snapshotState(StateSnapshotContext context) throws Exception {
         super.snapshotState(context);
+
+        // Always clear the union list state before setting a new value.
+        parallelismState.clear();
+        if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
+            parallelismState.update(
+                    Collections.singletonList(getRuntimeContext().getNumberOfParallelSubtasks()));
+        }
+        statusState.update(Collections.singletonList(status.ordinal()));
+
+        HeadOperatorState currentProcessorState = recordProcessor.snapshotState();
+        if (currentProcessorState != null) {
+            processorState.update(Collections.singletonList(currentProcessorState));
+        } else {
+            processorState.clear();
+        }
+
+        if (status == HeadOperatorStatus.RUNNING) {
+            checkpoints.startLogging(
+                    context.getCheckpointId(), context.getRawOperatorStateOutput());
+        }
+
         checkpointAligner
                 .onStateSnapshot(context.getCheckpointId())
                 .forEach(this::processGloballyAlignedEvent);
     }
 
     @Override
+    public void notifyCheckpointAborted(long checkpointId) throws Exception {
+        super.notifyCheckpointAborted(checkpointId);
+
+        checkpointAligner
+                .onCheckpointAborted(checkpointId)
+                .forEach(this::processGloballyAlignedEvent);
+    }
+
+    @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
         recordProcessor.processElement(element);
     }
 
     @Override
     public void processFeedback(StreamRecord<IterationRecord<?>> iterationRecord) throws Exception {
+        if (iterationRecord.getValue().getType() == IterationRecord.Type.BARRIER) {
+            checkpoints.commitCheckpointsUntil(iterationRecord.getValue().getCheckpointId());
+            return;
+        }
+
+        checkpoints.append(iterationRecord.getValue());
         boolean terminated = recordProcessor.processFeedbackElement(iterationRecord);
         if (terminated) {
             checkState(status == HeadOperatorStatus.TERMINATING);
@@ -213,13 +360,65 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     @Override
     public void endInput() throws Exception {
-        recordProcessor.processElement(
-                new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+        if (status == HeadOperatorStatus.RUNNING) {
+            recordProcessor.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+        }
+
+        // Since we choose to block here, we can no longer process the barriers received
+        // from the task inputs, which would prevent the preceding tasks from finishing,
+        // since they need to complete their final checkpoint first. As a temporary
+        // workaround, we inspect the input channels ourselves and trigger every pending
+        // checkpoint until we see the EndOfPartitionEvent.
+        checkState(getContainingTask().getEnvironment().getAllInputGates().length == 1);
+        checkState(
+                getContainingTask()
+                                .getEnvironment()
+                                .getAllInputGates()[0]
+                                .getNumberOfInputChannels()
+                        == 1);
+        InputChannel inputChannel =
+                getContainingTask().getEnvironment().getAllInputGates()[0].getChannel(0);
+
+        boolean endOfPartitionReceived = false;
+        long lastTriggerCheckpointId = 0;
+        while (!endOfPartitionReceived && status != HeadOperatorStatus.TERMINATED) {
+            mailboxExecutor.tryYield();
+            Thread.sleep(200);
+
+            List<AbstractEvent> events = parseInputChannelEvents(inputChannel);
+
+            for (AbstractEvent event : events) {
+                if (event instanceof CheckpointBarrier) {
+                    CheckpointBarrier barrier = (CheckpointBarrier) event;
+                    if (barrier.getId() > lastTriggerCheckpointId) {
+                        getContainingTask()
+                                .triggerCheckpointAsync(
+                                        new CheckpointMetaData(
+                                                barrier.getId(), barrier.getTimestamp()),
+                                        barrier.getCheckpointOptions());
+                        lastTriggerCheckpointId = barrier.getId();
+                    }
+
+                } else if (event instanceof EndOfPartitionEvent) {
+                    endOfPartitionReceived = true;
+                }
+            }
+        }
+
+        // From here on we can step into the normal termination loop.
         while (status != HeadOperatorStatus.TERMINATED) {
             mailboxExecutor.yield();
         }
     }
 
+    @Override
+    public void close() throws Exception {
+        if (checkpoints != null) {
+            checkpoints.close();
+        }
+    }
+
     private void registerFeedbackConsumer(Executor mailboxExecutor) {
         int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
         int attemptNum = getRuntimeContext().getAttemptNumber();
@@ -232,6 +431,48 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         OperatorUtils.registerFeedbackConsumer(channel, this, mailboxExecutor);
     }
 
+    private List<AbstractEvent> parseInputChannelEvents(InputChannel inputChannel)
+            throws Exception {
+        List<AbstractEvent> events = new ArrayList<>();
+        if (inputChannel instanceof RemoteInputChannel) {
+            Class<?> seqBufferClass =
+                    Class.forName(
+                            "org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel$SequenceBuffer");
+            PrioritizedDeque<?> queue =
+                    ReflectionUtils.getFieldValue(
+                            inputChannel, RemoteInputChannel.class, "receivedBuffers");
+            for (Object sequenceBuffer : queue) {
+                Buffer buffer =
+                        ReflectionUtils.getFieldValue(sequenceBuffer, seqBufferClass, "buffer");
+                if (!buffer.isBuffer()) {
+                    events.add(EventSerializer.fromBuffer(buffer, getClass().getClassLoader()));
+                }
+            }
+        } else if (inputChannel instanceof LocalInputChannel) {
+            PipelinedSubpartitionView subpartitionView =
+                    ReflectionUtils.getFieldValue(
+                            inputChannel, LocalInputChannel.class, "subpartitionView");
+            PipelinedSubpartition pipelinedSubpartition =
+                    ReflectionUtils.getFieldValue(
+                            subpartitionView, PipelinedSubpartitionView.class, "parent");
+            PrioritizedDeque<BufferConsumerWithPartialRecordLength> queue =
+                    ReflectionUtils.getFieldValue(
+                            pipelinedSubpartition, PipelinedSubpartition.class, "buffers");
+            for (BufferConsumerWithPartialRecordLength bufferConsumer : queue) {
+                if (!bufferConsumer.getBufferConsumer().isBuffer()) {
+                    events.add(
+                            EventSerializer.fromBuffer(
+                                    bufferConsumer.getBufferConsumer().copy().build(),
+                                    getClass().getClassLoader()));
+                }
+            }
+        } else {
+            LOG.warn("Unknown input channel type: " + inputChannel);
+        }
+
+        return events;
+    }
+
     @VisibleForTesting
     public OperatorEventGateway getOperatorEventGateway() {
         return operatorEventGateway;
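
The endInput() workaround above amounts to a polling loop over the input channel.
A condensed, self-contained sketch (the Event types and functional parameters are
hypothetical stand-ins for CheckpointBarrier, EndOfPartitionEvent, and the task's
checkpoint-trigger API):

    import java.util.List;
    import java.util.function.BooleanSupplier;
    import java.util.function.LongConsumer;
    import java.util.function.Supplier;

    interface Event {}

    final class Barrier implements Event {
        final long id;
        Barrier(long id) { this.id = id; }
    }

    final class EndOfPartition implements Event {}

    // While blocked waiting for termination, keep polling the input channel,
    // trigger each newly seen checkpoint exactly once, and stop once the
    // upstream partition has ended.
    final class DrainLoop {
        static void drain(
                Supplier<List<Event>> pollChannel,
                LongConsumer triggerCheckpoint,
                BooleanSupplier terminated)
                throws InterruptedException {
            boolean endOfPartition = false;
            long lastTriggered = 0;
            while (!endOfPartition && !terminated.getAsBoolean()) {
                Thread.sleep(200); // Same pacing as the operator's loop.
                for (Event e : pollChannel.get()) {
                    if (e instanceof Barrier && ((Barrier) e).id > lastTriggered) {
                        triggerCheckpoint.accept(((Barrier) e).id);
                        lastTriggered = ((Barrier) e).id;
                    } else if (e instanceof EndOfPartition) {
                        endOfPartition = true;
                    }
                }
            }
        }
    }
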
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
index b4f4215..9b83a7d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
@@ -24,6 +24,7 @@ import org.apache.flink.util.function.RunnableWithException;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.TreeMap;
 
@@ -40,6 +41,8 @@ class HeadOperatorCheckpointAligner {
 
     private long latestCheckpointFromCoordinator;
 
+    private long latestAbortedCheckpoint;
+
     HeadOperatorCheckpointAligner() {
         this.checkpointAlignmments = new TreeMap<>();
     }
@@ -58,6 +61,13 @@ class HeadOperatorCheckpointAligner {
     void coordinatorNotify(CoordinatorCheckpointEvent checkpointEvent) {
         checkState(checkpointEvent.getCheckpointId() > latestCheckpointFromCoordinator);
         latestCheckpointFromCoordinator = checkpointEvent.getCheckpointId();
+
+        // Do nothing if a later checkpoint has been aborted. In this case there will be
+        // no notification from the task side.
+        if (latestCheckpointFromCoordinator <= latestAbortedCheckpoint) {
+            return;
+        }
+
         CheckpointAlignment checkpointAlignment =
                 checkpointAlignmments.computeIfAbsent(
                         checkpointEvent.getCheckpointId(),
@@ -86,6 +96,22 @@ class HeadOperatorCheckpointAligner {
         return checkpointAlignment.pendingGlobalEvents;
     }
 
+    List<GloballyAlignedEvent> onCheckpointAborted(long checkpointId) {
+        // Here we need to abort all the checkpoints with id <= the notified checkpoint id.
+        checkState(checkpointId > latestAbortedCheckpoint);
+        latestAbortedCheckpoint = checkpointId;
+
+        Map<Long, CheckpointAlignment> abortedAlignments =
+                checkpointAlignmments.headMap(latestAbortedCheckpoint, true);
+        List<GloballyAlignedEvent> events = new ArrayList<>();
+        abortedAlignments
+                .values()
+                .forEach(alignment -> events.addAll(alignment.pendingGlobalEvents));
+        abortedAlignments.clear();
+
+        return events;
+    }
+
     private static class CheckpointAlignment {
 
         final List<GloballyAlignedEvent> pendingGlobalEvents;
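
The abort path added above can be modeled with a TreeMap alone. The small sketch
below (illustrative names, no Flink types) shows why headMap(id, true) plus clear()
releases every alignment up to and including the aborted checkpoint, handing back the
globally-aligned events that were deferred while the snapshot was pending:

    import java.util.ArrayList;
    import java.util.List;
    import java.util.Map;
    import java.util.TreeMap;

    final class AbortModel<E> {

        private final TreeMap<Long, List<E>> pendingByCheckpoint = new TreeMap<>();
        private long latestAborted;

        void defer(long checkpointId, E event) {
            pendingByCheckpoint
                    .computeIfAbsent(checkpointId, k -> new ArrayList<>())
                    .add(event);
        }

        List<E> onAborted(long checkpointId) {
            if (checkpointId <= latestAborted) {
                throw new IllegalStateException("Aborts must arrive with increasing ids");
            }
            latestAborted = checkpointId;
            // headMap(id, true) is a live view of all alignments <= id.
            Map<Long, List<E>> abortedView = pendingByCheckpoint.headMap(checkpointId, true);
            List<E> released = new ArrayList<>();
            abortedView.values().forEach(released::addAll);
            abortedView.clear(); // Clearing the view removes them from the TreeMap.
            return released;
        }
    }
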
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
index 2bd1e38..5040506 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.CoordinatedOperatorFactory;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -132,4 +133,11 @@ public class HeadOperatorFactory extends AbstractStreamOperatorFactory<Iteration
         // We need it to be yielding operator factory to disable chaining,
         // but we cannot use the given mailbox here since it has bugs.
     }
+
+    @Override
+    public ChainingStrategy getChainingStrategy() {
+        // We cannot allow the head operator to be chained with the previous operator
+        // because of the special treatment in endInput.
+        return ChainingStrategy.HEAD;
+    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
similarity index 51%
copy from flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
copy to flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
index 861f481..fc4cd07 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
@@ -16,7 +16,27 @@
  * limitations under the License.
  */
 
-package org.apache.flink.iteration.operator.headprocessor;
+package org.apache.flink.iteration.operator;
 
-/** The state entry for the head operator. */
-public class HeadOperatorState {}
+import org.apache.flink.api.common.state.ListState;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Utility to deal with the states inside the operator. */
+public class OperatorStateUtils {
+
+    public static <T> Optional<T> getUniqueElement(ListState<T> listState, String stateName)
+            throws Exception {
+        Iterator<T> iterator = listState.get().iterator();
+        if (!iterator.hasNext()) {
+            return Optional.empty();
+        }
+
+        T result = iterator.next();
+        checkState(!iterator.hasNext(), "The state " + stateName + " has more than one element");
+        return Optional.of(result);
+    }
+}
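
A short usage sketch for getUniqueElement follows; the in-memory ListState below is a
throwaway test double written for this example, not a Flink-provided class:

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.iteration.operator.OperatorStateUtils;

    import java.util.ArrayList;
    import java.util.List;

    public class UniqueElementDemo {

        // Minimal in-memory ListState, sufficient for demonstration purposes.
        static final class InMemoryListState<T> implements ListState<T> {
            private final List<T> values = new ArrayList<>();

            @Override public Iterable<T> get() { return values; }
            @Override public void add(T value) { values.add(value); }
            @Override public void update(List<T> newValues) { values.clear(); values.addAll(newValues); }
            @Override public void addAll(List<T> newValues) { values.addAll(newValues); }
            @Override public void clear() { values.clear(); }
        }

        public static void main(String[] args) throws Exception {
            ListState<Integer> status = new InMemoryListState<>();
            // Empty state: Optional.empty
            System.out.println(OperatorStateUtils.getUniqueElement(status, "status"));
            status.add(1);
            // Exactly one element: Optional[1]
            System.out.println(OperatorStateUtils.getUniqueElement(status, "status"));
            status.add(2);
            // A second element now fails the uniqueness checkState.
        }
    }
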
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
index a58155e..2e26142 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
@@ -20,6 +20,8 @@ package org.apache.flink.iteration.operator;
 
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.checkpoint.Checkpoints;
+import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
@@ -90,6 +92,31 @@ public class TailOperator extends AbstractStreamOperator<Void>
         recordConsumer.accept(streamRecord);
     }
 
+    @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        super.prepareSnapshotPreBarrier(checkpointId);
+        channel.put(new StreamRecord<>(IterationRecord.newBarrier(checkpointId)));
+    }
+
+    @Override
+    public void notifyCheckpointAborted(long checkpointId) throws Exception {
+        super.notifyCheckpointAborted(checkpointId);
+
+    // TODO: Unfortunately, we have to rely on the tail operator to help
+    // abort the checkpoint, since the task thread of the head operator
+    // might get blocked by not being able to close the raw state files.
+    // We will try to fix this on the Flink side in the future.
+        SubtaskFeedbackKey<?> key =
+                OperatorUtils.createFeedbackKey(iterationId, feedbackIndex)
+                        .withSubTaskIndex(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                getRuntimeContext().getAttemptNumber());
+        Checkpoints<?> checkpoints = CheckpointsBroker.get().getCheckpoints(key);
+        if (checkpoints != null) {
+            checkpoints.abort(checkpointId);
+        }
+    }
+
     private void processIfObjectReuseEnabled(StreamRecord<IterationRecord<?>> record) {
         // Since the record would be reused, we have to clone a new one
         IterationRecord<?> cloned = record.getValue().clone();
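
The barrier injected in prepareSnapshotPreBarrier is what tells the head how much of
the feedback stream belongs to a given checkpoint. A toy model of that protocol (all
types are hypothetical simplifications of IterationRecord and the feedback channel):

    import java.util.Queue;
    import java.util.concurrent.ConcurrentLinkedQueue;

    final class FeedbackBarrierDemo {

        enum Kind { RECORD, BARRIER }

        static final class Rec {
            final Kind kind;
            final long checkpointId;
            final Object payload;

            Rec(Kind kind, long checkpointId, Object payload) {
                this.kind = kind;
                this.checkpointId = checkpointId;
                this.payload = payload;
            }
        }

        public static void main(String[] args) {
            Queue<Rec> feedback = new ConcurrentLinkedQueue<>();
            feedback.add(new Rec(Kind.RECORD, -1, "r1"));
            // Tail side: prepareSnapshotPreBarrier(2) injects the barrier.
            feedback.add(new Rec(Kind.BARRIER, 2, null));

            // Head side: log records into pending checkpoints until the barrier
            // arrives, then commit everything up to its checkpoint id.
            for (Rec r; (r = feedback.poll()) != null; ) {
                if (r.kind == Kind.BARRIER) {
                    System.out.println("commit checkpoints <= " + r.checkpointId);
                } else {
                    System.out.println("log feedback record " + r.payload);
                }
            }
        }
    }
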
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
index 2c01c33..f4d6e2d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
@@ -63,10 +63,16 @@ public class HeadOperatorCoordinator implements OperatorCoordinator, SharedProgr
     }
 
     @Override
-    public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {}
+    public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {
+        for (int i = 0; i < context.currentParallelism(); ++i) {
+            sharedProgressAligner.removeProgressInfo(context.getOperatorId());
+        }
+    }
 
     @Override
-    public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {}
+    public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {
+        sharedProgressAligner.removeProgressInfo(context.getOperatorId(), subtaskIndex);
+    }
 
     @Override
     public void handleEventFromOperator(int subtaskIndex, OperatorEvent operatorEvent) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
index 9c2fd43..4c2204d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
@@ -182,6 +182,24 @@ public class SharedProgressAligner {
                 checkpointId);
     }
 
+    public void removeProgressInfo(OperatorID operatorId) {
+        runInEventLoop(
+                () -> statusByEpoch.values().forEach(status -> status.remove(operatorId)),
+                "remove the progress information for {}",
+                operatorId);
+    }
+
+    public void removeProgressInfo(OperatorID operatorId, int subtaskIndex) {
+        runInEventLoop(
+                () ->
+                        statusByEpoch
+                                .values()
+                                .forEach(status -> status.remove(operatorId, subtaskIndex)),
+                "remove the progress information for {}-{}",
+                operatorId,
+                subtaskIndex);
+    }
+
     private void runInEventLoop(
             ThrowingRunnable<Throwable> action,
             String actionName,
@@ -234,6 +252,16 @@ public class SharedProgressAligner {
             return reportedSubtasks.size() == totalHeadParallelism;
         }
 
+        public void remove(OperatorID operatorID) {
+            reportedSubtasks
+                    .entrySet()
+                    .removeIf(entry -> entry.getKey().getOperatorId().equals(operatorID));
+        }
+
+        public void remove(OperatorID operatorID, int subtaskIndex) {
+            reportedSubtasks.remove(new OperatorInstanceID(subtaskIndex, operatorID));
+        }
+
         public boolean isTerminated() {
             checkState(
                     reportedSubtasks.size() == totalHeadParallelism,
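
The two remove() variants above differ only in how much reported progress they
discard. A compact model (string ids and Map.Entry keys are stand-ins for OperatorID
and OperatorInstanceID):

    import java.util.AbstractMap.SimpleEntry;
    import java.util.HashMap;
    import java.util.Map;

    final class ProgressReset {
        public static void main(String[] args) {
            // (operatorId, subtaskIndex) -> reported number of feedback records.
            Map<Map.Entry<String, Integer>, Long> reportedSubtasks = new HashMap<>();
            reportedSubtasks.put(new SimpleEntry<>("opA", 0), 10L);
            reportedSubtasks.put(new SimpleEntry<>("opA", 1), 12L);

            // subtaskFailed(1): drop one subtask's report so the restarted
            // attempt can report again.
            reportedSubtasks.remove(new SimpleEntry<>("opA", 1));

            // resetToCheckpoint: drop everything reported for the operator.
            reportedSubtasks.entrySet().removeIf(e -> e.getKey().getKey().equals("opA"));
            System.out.println(reportedSubtasks); // {}
        }
    }
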
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
index 78132c3..26d397d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
@@ -22,14 +22,18 @@ import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.HeadOperator;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.OutputTag;
 
+import javax.annotation.Nullable;
+
 /** The component to actually deal with the event received in the {@link HeadOperator}. */
 public interface HeadOperatorRecordProcessor {
 
-    void initializeState(HeadOperatorState headOperatorState) throws Exception;
+    void initializeState(
+            HeadOperatorState headOperatorState, Iterable<StatePartitionStreamProvider> rawStates);
 
     void processElement(StreamRecord<IterationRecord<?>> record);
 
@@ -37,6 +41,7 @@ public interface HeadOperatorRecordProcessor {
 
     boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent);
 
+    @Nullable
     HeadOperatorState snapshotState();
 
     /** The context for {@link HeadOperatorRecordProcessor}. */
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
index 861f481..996d929 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
@@ -18,5 +18,35 @@
 
 package org.apache.flink.iteration.operator.headprocessor;
 
+import java.util.Map;
+
 /** The state entry for the head operator. */
-public class HeadOperatorState {}
+public class HeadOperatorState {
+
+    private Map<Integer, Long> numFeedbackRecordsEachRound;
+
+    private int latestRoundAligned;
+
+    private int latestRoundGloballyAligned;
+
+    public HeadOperatorState(
+            Map<Integer, Long> numFeedbackRecordsEachRound,
+            int latestRoundAligned,
+            int latestRoundGloballyAligned) {
+        this.numFeedbackRecordsEachRound = numFeedbackRecordsEachRound;
+        this.latestRoundAligned = latestRoundAligned;
+        this.latestRoundGloballyAligned = latestRoundGloballyAligned;
+    }
+
+    public Map<Integer, Long> getNumFeedbackRecordsEachRound() {
+        return numFeedbackRecordsEachRound;
+    }
+
+    public int getLatestRoundAligned() {
+        return latestRoundAligned;
+    }
+
+    public int getLatestRoundGloballyAligned() {
+        return latestRoundGloballyAligned;
+    }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
index f1a6b0f..107a233 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 import org.slf4j.Logger;
@@ -30,6 +31,9 @@ import org.slf4j.LoggerFactory;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
 /**
  * Processes the event before we received the terminated global aligned event from the coordinator.
  */
@@ -40,26 +44,46 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
 
     private final Context headOperatorContext;
 
-    private final StreamRecord<IterationRecord<?>> reusable;
-
     private final Map<Integer, Long> numFeedbackRecordsPerEpoch;
 
     private final String senderId;
 
+    private int latestRoundAligned;
+
+    private int latestRoundGloballyAligned;
+
     public RegularHeadOperatorRecordProcessor(Context headOperatorContext) {
         this.headOperatorContext = headOperatorContext;
 
-        this.reusable = new StreamRecord<>(null);
         this.numFeedbackRecordsPerEpoch = new HashMap<>();
 
         this.senderId =
                 OperatorUtils.getUniqueSenderId(
                         headOperatorContext.getStreamConfig().getOperatorID(),
                         headOperatorContext.getTaskInfo().getIndexOfThisSubtask());
+
+        this.latestRoundAligned = -1;
+        this.latestRoundGloballyAligned = -1;
     }
 
     @Override
-    public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+    public void initializeState(
+            HeadOperatorState headOperatorState, Iterable<StatePartitionStreamProvider> rawStates) {
+        checkArgument(headOperatorState != null, "The initialized state should not be null");
+
+        numFeedbackRecordsPerEpoch.putAll(headOperatorState.getNumFeedbackRecordsEachRound());
+        latestRoundAligned = headOperatorState.getLatestRoundAligned();
+        latestRoundGloballyAligned = headOperatorState.getLatestRoundGloballyAligned();
+
+        // If the only round not fully aligned is round 0, then wait till endOfInput in
+        // case the input has changed.
+        if (!(latestRoundAligned == 0 && latestRoundGloballyAligned == -1)) {
+            for (int i = latestRoundGloballyAligned + 1; i <= latestRoundAligned; ++i) {
+                headOperatorContext.updateEpochToCoordinator(
+                        i, numFeedbackRecordsPerEpoch.getOrDefault(i, 0L));
+            }
+        }
+    }
 
     @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) {
@@ -81,22 +105,34 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
     @Override
     public boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent) {
         LOG.info("Received global event {}", globallyAlignedEvent);
-
-        reusable.replace(
-                IterationRecord.newEpochWatermark(
-                        globallyAlignedEvent.isTerminated()
-                                ? Integer.MAX_VALUE
-                                : globallyAlignedEvent.getEpoch(),
-                        senderId),
-                0);
-        headOperatorContext.broadcastOutput(reusable);
-
+        checkState(
+                (globallyAlignedEvent.getEpoch() == 0 && latestRoundGloballyAligned == 0)
+                        || globallyAlignedEvent.getEpoch() > latestRoundGloballyAligned,
+                String.format(
+                        "Receive unexpected global aligned event, latest = %d, this one = %d",
+                        latestRoundGloballyAligned, globallyAlignedEvent.getEpoch()));
+
+        StreamRecord<IterationRecord<?>> record =
+                new StreamRecord<>(
+                        IterationRecord.newEpochWatermark(
+                                globallyAlignedEvent.isTerminated()
+                                        ? Integer.MAX_VALUE
+                                        : globallyAlignedEvent.getEpoch(),
+                                senderId),
+                        0);
+        headOperatorContext.broadcastOutput(record);
+
+        latestRoundGloballyAligned =
+                Math.max(globallyAlignedEvent.getEpoch(), latestRoundGloballyAligned);
         return globallyAlignedEvent.isTerminated();
     }
 
     @Override
     public HeadOperatorState snapshotState() {
-        return new HeadOperatorState();
+        return new HeadOperatorState(
+                new HashMap<>(numFeedbackRecordsPerEpoch),
+                latestRoundAligned,
+                latestRoundGloballyAligned);
     }
 
     @VisibleForTesting
@@ -104,18 +140,50 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
         return numFeedbackRecordsPerEpoch;
     }
 
+    @VisibleForTesting
+    public int getLatestRoundAligned() {
+        return latestRoundAligned;
+    }
+
+    @VisibleForTesting
+    public int getLatestRoundGloballyAligned() {
+        return latestRoundGloballyAligned;
+    }
+
     private void processRecord(StreamRecord<IterationRecord<?>> iterationRecord) {
         switch (iterationRecord.getValue().getType()) {
             case RECORD:
-                reusable.replace(iterationRecord.getValue(), iterationRecord.getTimestamp());
-                headOperatorContext.output(reusable);
+                headOperatorContext.output(iterationRecord);
                 break;
             case EPOCH_WATERMARK:
                 LOG.info("Head Received epoch watermark {}", iterationRecord.getValue().getEpoch());
-                headOperatorContext.updateEpochToCoordinator(
-                        iterationRecord.getValue().getEpoch(),
-                        numFeedbackRecordsPerEpoch.getOrDefault(
-                                iterationRecord.getValue().getEpoch(), 0L));
+
+                boolean needNotifyCoordinator = false;
+                if (iterationRecord.getValue().getEpoch() == 0) {
+                    if (latestRoundAligned <= 0) {
+                        needNotifyCoordinator = true;
+                    }
+                } else {
+                    checkState(
+                            iterationRecord.getValue().getEpoch() > latestRoundAligned,
+                            String.format(
+                                    "Unexpected epoch watermark: latest = %d, this one = %d",
+                                    latestRoundAligned, iterationRecord.getValue().getEpoch()));
+                    headOperatorContext.updateEpochToCoordinator(
+                            iterationRecord.getValue().getEpoch(),
+                            numFeedbackRecordsPerEpoch.getOrDefault(
+                                    iterationRecord.getValue().getEpoch(), 0L));
+                }
+
+                if (needNotifyCoordinator) {
+                    headOperatorContext.updateEpochToCoordinator(
+                            iterationRecord.getValue().getEpoch(),
+                            numFeedbackRecordsPerEpoch.getOrDefault(
+                                    iterationRecord.getValue().getEpoch(), 0L));
+                }
+
+                latestRoundAligned =
+                        Math.max(iterationRecord.getValue().getEpoch(), latestRoundAligned);
                 break;
         }
     }
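
The epoch bookkeeping in processRecord boils down to a small guard. A stripped-down
sketch (illustrative class, no Flink types): round 0 may legitimately be reported
again after a restore, because endInput re-sends its epoch watermark, while every
later round must be strictly larger than any round aligned so far.

    final class EpochTracker {

        private int latestRoundAligned = -1;

        boolean shouldNotifyCoordinator(int epoch) {
            boolean notify;
            if (epoch == 0) {
                // Tolerate (and re-report) the round-0 watermark after recovery.
                notify = latestRoundAligned <= 0;
            } else {
                if (epoch <= latestRoundAligned) {
                    throw new IllegalStateException(
                            "Unexpected epoch watermark: latest = " + latestRoundAligned
                                    + ", this one = " + epoch);
                }
                notify = true;
            }
            latestRoundAligned = Math.max(epoch, latestRoundAligned);
            return notify;
        }
    }
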
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
index c5377e5..f18bfa3 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.headprocessor;
 
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.FlinkRuntimeException;
 
@@ -30,7 +31,9 @@ import org.apache.flink.util.FlinkRuntimeException;
 public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecordProcessor {
 
     @Override
-    public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+    public void initializeState(
+            HeadOperatorState headOperatorState,
+            Iterable<StatePartitionStreamProvider> rawStates) {}
 
     @Override
     public void processElement(StreamRecord<IterationRecord<?>> record) {
@@ -55,6 +58,6 @@ public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecor
 
     @Override
     public HeadOperatorState snapshotState() {
-        return new HeadOperatorState();
+        return null;
     }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index 9fa0d28..3cd4bf2 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -79,6 +79,44 @@ public class IterationConstructionTest extends TestLogger {
     }
 
     @Test
+    public void testNotChainingHeadOperator() {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(4);
+        DataStream<Integer> variableSource =
+                env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
+                        .name("Variable")
+                        .map(x -> x)
+                        .name("map")
+                        .setParallelism(2);
+        DataStreamList result =
+                Iterations.iterateUnboundedStreams(
+                        DataStreamList.of(variableSource),
+                        DataStreamList.of(),
+                        ((variableStreams, dataStreams) ->
+                                new IterationBodyResult(variableStreams, dataStreams)));
+
+        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+        List<String> expectedVertexNames =
+                Arrays.asList(
+                        /* 0 */ "Source: Variable",
+                        /* 1 */ "map -> input-map",
+                        /* 2 */ "head-map",
+                        /* 3 */ "tail-head-map");
+        List<Integer> expectedParallelisms = Arrays.asList(4, 2, 2, 2);
+
+        List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+        assertEquals(
+                expectedVertexNames,
+                vertices.stream().map(JobVertex::getName).collect(Collectors.toList()));
+        assertEquals(
+                expectedParallelisms,
+                vertices.stream().map(JobVertex::getParallelism).collect(Collectors.toList()));
+        assertNotNull(vertices.get(2).getCoLocationGroup());
+        assertSame(vertices.get(2).getCoLocationGroup(), vertices.get(3).getCoLocationGroup());
+    }
+
+    @Test
     public void testUnboundedIteration() {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         DataStream<Integer> variableSource1 =
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index e6d1e26..5603cc9 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
@@ -43,6 +44,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
@@ -56,7 +58,10 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
@@ -157,7 +162,10 @@ public class HeadOperatorTest extends TestLogger {
 
                                             while (RecordingHeadOperatorFactory.latestHeadOperator
                                                             .getStatus()
-                                                    == HeadOperator.HeadOperatorStatus.RUNNING) {}
+                                                    == HeadOperator.HeadOperatorStatus.RUNNING) {
+                                                Thread.sleep(500);
+                                            }
+
                                             putFeedbackRecords(
                                                     iterationId,
                                                     IterationRecord.newEpochWatermark(
@@ -312,6 +320,452 @@ public class HeadOperatorTest extends TestLogger {
                 });
     }
 
+    @Test
+    public void testSnapshotAndRestoreBeforeRoundZeroFinish() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            harness.processElement(
+                                    new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.emptyList(),
+                            Collections.emptyMap(),
+                            -1,
+                            -1);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreAfterRoundZeroFinishAndRoundOneNotAligned() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            harness.processElement(
+                                    new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+
+                            // Simulates endOfInput, but does not block the main thread.
+                            harness.processElement(
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(0, "fake")));
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.emptyList(),
+                            Collections.emptyMap(),
+                            0,
+                            -1);
+
+                    // Simulates endOfInput, but does not block the main thread.
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+                    assertEquals(
+                            Collections.singletonList(new SubtaskAlignedEvent(0, 0, false)),
+                            new ArrayList<>(
+                                    ((RecordingOperatorEventGateway)
+                                                    RecordingHeadOperatorFactory.latestHeadOperator
+                                                            .getOperatorEventGateway())
+                                            .operatorEvents));
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreAfterRoundZeroFinishAndRoundOneAligned() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            harness.processElement(
+                                    new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+
+                            // Simulates endOfInput, but does not block the main thread.
+                            harness.processElement(
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(0, "fake")));
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(0, false));
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(100, 1), null);
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(1, "tail"),
+                                    null);
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.singletonList(new SubtaskAlignedEvent(1, 1, false)),
+                            Collections.singletonMap(1, 1L),
+                            1,
+                            0);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreWithFeedbackRecords() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(4, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(4, false));
+                            harness.processAll();
+
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(100, 5), null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(101, 5), null);
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(102, 5), null);
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(103, 6), null);
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(5, "tail"),
+                                    null);
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Arrays.asList(
+                                    new StreamRecord<>(IterationRecord.newRecord(101, 5)),
+                                    new StreamRecord<>(IterationRecord.newRecord(102, 5)),
+                                    new StreamRecord<>(IterationRecord.newRecord(103, 6))),
+                            /* The one record before the checkpoint plus the two after it */
+                            Collections.singletonList(new SubtaskAlignedEvent(5, 3, false)),
+                            new HashMap<Integer, Long>() {
+                                {
+                                    this.put(5, 3L);
+                                    this.put(6, 1L);
+                                }
+                            },
+                            5,
+                            4);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testCheckpointBeforeTerminated() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(4, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(4, false));
+                            harness.processAll();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(5, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(5, true));
+                            harness.processAll();
+
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.singletonList(new SubtaskAlignedEvent(5, 0, false)),
+                            Collections.emptyMap(),
+                            5,
+                            4);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testCheckpointAfterTerminating() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(5, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(5, true));
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.TERMINATING,
+                            Collections.emptyList(),
+                            Collections.emptyList(),
+                            Collections.emptyMap(),
+                            -1,
+                            -1);
+
+                    putFeedbackRecords(
+                            iterationId,
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE + 1, "tail"),
+                            null);
+                    harness.processEvent(EndOfData.INSTANCE);
+                    harness.finishProcessing();
+
+                    return null;
+                });
+    }
+
+    @Test(timeout = 20000)
+    public void testTailAbortPendingCheckpointIfHeadBlocked() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.processElement(new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+                    dispatchOperatorEvent(harness, operatorId, new CoordinatorCheckpointEvent(2));
+                    harness.getStreamTask()
+                            .triggerCheckpointAsync(
+                                    new CheckpointMetaData(2, 1000),
+                                    CheckpointOptions.alignedNoTimeout(
+                                            CheckpointType.CHECKPOINT,
+                                            CheckpointStorageLocationReference.getDefault()));
+                    harness.processAll();
+
+                    putFeedbackRecords(iterationId, IterationRecord.newRecord(100, 1), null);
+                    harness.processAll();
+
+                    // Simulates the tail operator helping to abort the pending checkpoint
+                    CompletableFuture<Void> supplier =
+                            CompletableFuture.supplyAsync(
+                                    () -> {
+                                        try {
+                                            // Slightly postpone the execution until the head
+                                            // operator gets blocked.
+                                            Thread.sleep(2000);
+
+                                            OneInputStreamOperatorTestHarness<
+                                                            IterationRecord<?>, Void>
+                                                    testHarness =
+                                                            new OneInputStreamOperatorTestHarness<>(
+                                                                    new TailOperator(
+                                                                            iterationId, 0));
+                                            testHarness.open();
+
+                                            testHarness.getOperator().notifyCheckpointAborted(2);
+                                        } catch (Exception e) {
+                                            throw new CompletionException(e);
+                                        }
+
+                                        return null;
+                                    });
+
+                    harness.getStreamTask().notifyCheckpointAbortAsync(2, 0);
+                    harness.processAll();
+
+                    supplier.get();
+
+                    return null;
+                });
+    }
+
     private <T> T createHarnessAndRun(
             IterationID iterationId,
             OperatorID operatorId,
@@ -371,6 +825,46 @@ public class HeadOperatorTest extends TestLogger {
                         : new StreamRecord<>(record, timestamp));
     }
 
+    private static void checkRestoredOperatorState(
+            StreamTaskMailboxTestHarness<?> harness,
+            HeadOperator.HeadOperatorStatus expectedStatus,
+            List<Object> expectedOutput,
+            List<OperatorEvent> expectedOperatorEvents,
+            Map<Integer, Long> expectedNumFeedbackRecords,
+            int expectedLastAligned,
+            int expectedLastGloballyAligned) {
+        HeadOperator headOperator = RecordingHeadOperatorFactory.latestHeadOperator;
+        assertEquals(expectedStatus, headOperator.getStatus());
+        assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
+        RecordingOperatorEventGateway eventGateway =
+                (RecordingOperatorEventGateway) headOperator.getOperatorEventGateway();
+        assertEquals(expectedOperatorEvents, new ArrayList<>(eventGateway.operatorEvents));
+
+        if (expectedStatus == HeadOperator.HeadOperatorStatus.RUNNING) {
+            RegularHeadOperatorRecordProcessor recordProcessor =
+                    (RegularHeadOperatorRecordProcessor) headOperator.getRecordProcessor();
+            assertEquals(
+                    expectedNumFeedbackRecords, recordProcessor.getNumFeedbackRecordsPerEpoch());
+            assertEquals(expectedLastAligned, recordProcessor.getLatestRoundAligned());
+            assertEquals(
+                    expectedLastGloballyAligned, recordProcessor.getLatestRoundGloballyAligned());
+        }
+    }
+
+    /**
+     * We have to manually clean up the feedback channel since we are not able to set the
+     * attempt number.
+     */
+    private static void cleanupFeedbackChannel(IterationID iterationId) {
+        FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
+                FeedbackChannelBroker.get()
+                        .getChannel(
+                                OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
+                                                iterationId, 0)
+                                        .withSubTaskIndex(0, 0));
+        feedbackChannel.close();
+    }
+
     private static class RecordingOperatorEventGateway implements OperatorEventGateway {
 
         final BlockingQueue<OperatorEvent> operatorEvents = new LinkedBlockingQueue<>();

[flink-ml] 08/08: [FLINK-24655][iteration] Add ITCase for the checkpoint and failover

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit b9ee412b0951d13ff5cf3d610a818a35e503a949
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Nov 1 23:56:09 2021 +0800

    [FLINK-24655][iteration] Add ITCase for the checkpoint and failover
    
    This closes #17.
---
 .../iteration/BoundedAllRoundCheckpointTest.java   | 196 +++++++++++++++++++++
 .../iteration/UnboundedStreamIterationITCase.java  |   5 +-
 .../flink/test/iteration/operators/FailingMap.java |  45 +++++
 .../operators/ReduceAllRoundProcessFunction.java   |  55 +++++-
 .../test/iteration/operators/SequenceSource.java   |  40 ++++-
 .../TwoInputReduceAllRoundProcessFunction.java     |  16 +-
 6 files changed, 347 insertions(+), 10 deletions(-)

diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
new file mode 100644
index 0000000..d53e334
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/BoundedAllRoundCheckpointTest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.compile.DraftExecutionEnvironment;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.test.iteration.operators.EpochRecord;
+import org.apache.flink.test.iteration.operators.FailingMap;
+import org.apache.flink.test.iteration.operators.IncrementEpochMap;
+import org.apache.flink.test.iteration.operators.OutputRecord;
+import org.apache.flink.test.iteration.operators.SequenceSource;
+import org.apache.flink.test.iteration.operators.TwoInputReduceAllRoundProcessFunction;
+import org.apache.flink.testutils.junit.SharedObjects;
+import org.apache.flink.testutils.junit.SharedReference;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.test.iteration.UnboundedStreamIterationITCase.createMiniClusterConfiguration;
+import static org.junit.Assert.assertEquals;
+
+/** Tests the checkpointing and failover of bounded all-round iterations. */
+@RunWith(Parameterized.class)
+public class BoundedAllRoundCheckpointTest extends TestLogger {
+
+    @Rule public final SharedObjects sharedObjects = SharedObjects.create();
+
+    private SharedReference<List<OutputRecord<Integer>>> result;
+
+    @Parameterized.Parameter(0)
+    public int failoverCount;
+
+    @Parameterized.Parameter(1)
+    public boolean sync;
+
+    @Parameterized.Parameters(name = "failoverCount = {0}, sync = {1}")
+    public static Collection<Object[]> params() {
+        int[] failoverCounts = {1000, 4000, 8000, 15900};
+        boolean[] syncs = {true, false};
+
+        List<Object[]> result = new ArrayList<>();
+        for (int failoverCount : failoverCounts) {
+            for (boolean sync : syncs) {
+                result.add(new Object[] {failoverCount, sync});
+            }
+        }
+
+        return result;
+    }
+
+    @Before
+    public void setup() {
+        result = sharedObjects.add(new ArrayList<>());
+    }
+
+    @Test
+    public void testFailoverAndRestore() throws Exception {
+        try (MiniCluster miniCluster = new MiniCluster(createMiniClusterConfiguration(2, 2))) {
+            miniCluster.start();
+
+            // Create the test job
+            JobGraph jobGraph =
+                    createVariableAndConstantJobGraph(
+                            4, 1000, false, 0, sync, 4, failoverCount, new CollectSink(result));
+            miniCluster.executeJobBlocking(jobGraph);
+
+            Map<Integer, Tuple2<Integer, Integer>> roundsStat = new HashMap<>();
+            for (OutputRecord<Integer> output : result.get()) {
+                Tuple2<Integer, Integer> state =
+                        roundsStat.computeIfAbsent(
+                                output.getRound(), ignored -> new Tuple2<>(0, 0));
+                state.f0++;
+                state.f1 = output.getValue();
+            }
+
+            // Rounds 0 ~ 4 plus the termination information
+            assertEquals(6, roundsStat.size());
+            for (int i = 0; i <= 4; ++i) {
+                // In this case we can only check the final result since the number of records
+                // per round is not deterministic. Each of the 4 sources emits 0..999, so the
+                // expected sum is 4 * (0 + 999) * 1000 / 2 = 1998000.
+                assertEquals(4 * (0 + 999) * 1000 / 2, (int) roundsStat.get(i).f1);
+            }
+        }
+    }
+
+    static JobGraph createVariableAndConstantJobGraph(
+            int numSources,
+            int numRecordsPerSource,
+            boolean holdSource,
+            int period,
+            boolean sync,
+            int maxRound,
+            int failoverCount,
+            SinkFunction<OutputRecord<Integer>> sinkFunction) {
+        StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(
+                        new Configuration() {
+                            {
+                                this.set(
+                                        ExecutionCheckpointingOptions
+                                                .ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
+                                        true);
+                            }
+                        });
+        env.enableCheckpointing(500, CheckpointingMode.EXACTLY_ONCE);
+        env.setParallelism(1);
+        DataStream<EpochRecord> variableSource =
+                env.addSource(new DraftExecutionEnvironment.EmptySource<EpochRecord>() {})
+                        .setParallelism(numSources)
+                        .name("Variable");
+        DataStream<EpochRecord> constSource =
+                env.addSource(new SequenceSource(numRecordsPerSource, holdSource, period))
+                        .setParallelism(numSources)
+                        .name("Constant");
+        DataStreamList outputs =
+                Iterations.iterateUnboundedStreams(
+                        DataStreamList.of(variableSource),
+                        DataStreamList.of(constSource),
+                        (variableStreams, dataStreams) -> {
+                            SingleOutputStreamOperator<EpochRecord> reducer =
+                                    variableStreams
+                                            .<EpochRecord>get(0)
+                                            .connect(dataStreams.<EpochRecord>get(0))
+                                            .process(
+                                                    new TwoInputReduceAllRoundProcessFunction(
+                                                            sync, maxRound));
+                            DataStream<EpochRecord> failedMap =
+                                    reducer.map(new FailingMap(failoverCount));
+                            return new IterationBodyResult(
+                                    DataStreamList.of(
+                                            failedMap
+                                                    .map(new IncrementEpochMap())
+                                                    .setParallelism(numSources)),
+                                    DataStreamList.of(
+                                            reducer.getSideOutput(
+                                                    new OutputTag<OutputRecord<Integer>>(
+                                                            "output") {})));
+                        });
+        outputs.<OutputRecord<Integer>>get(0).addSink(sinkFunction);
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static class CollectSink implements SinkFunction<OutputRecord<Integer>> {
+
+        private final SharedReference<List<OutputRecord<Integer>>> result;
+
+        private CollectSink(SharedReference<List<OutputRecord<Integer>>> result) {
+            this.result = result;
+        }
+
+        @Override
+        public void invoke(OutputRecord<Integer> value, Context context) throws Exception {
+            result.get().add(value);
+        }
+    }
+}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
index df084eb..6d80f23 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/UnboundedStreamIterationITCase.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.minicluster.MiniCluster;
 import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.test.iteration.operators.CollectSink;
 import org.apache.flink.test.iteration.operators.EpochRecord;
@@ -150,9 +151,11 @@ public class UnboundedStreamIterationITCase extends TestLogger {
         assertEquals(OutputRecord.Event.TERMINATED, result.get().take().getEvent());
     }
 
-    static MiniClusterConfiguration createMiniClusterConfiguration(int numTm, int numSlot) {
+    public static MiniClusterConfiguration createMiniClusterConfiguration(int numTm, int numSlot) {
         Configuration configuration = new Configuration();
         configuration.set(RestOptions.BIND_PORT, "18081-19091");
+        configuration.set(
+                ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
         return new MiniClusterConfiguration.Builder()
                 .setConfiguration(configuration)
                 .setNumTaskManagers(numTm)
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java
new file mode 100644
index 0000000..6c38d32
--- /dev/null
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/FailingMap.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.iteration.operators;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+
+/**
+ * Map function that triggers a failover in the first subtask, during its first attempt, once
+ * the configured number of records has been processed.
+ */
+public class FailingMap extends RichMapFunction<EpochRecord, EpochRecord> {
+
+    private final int failingCount;
+
+    private int count;
+
+    public FailingMap(int failingCount) {
+        this.failingCount = failingCount;
+    }
+
+    @Override
+    public EpochRecord map(EpochRecord value) throws Exception {
+        count++;
+        if (getRuntimeContext().getIndexOfThisSubtask() == 0
+                && getRuntimeContext().getAttemptNumber() == 0
+                && count >= failingCount) {
+            throw new RuntimeException("Artificial Exception");
+        }
+
+        return value;
+    }
+}
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
index dfce6e6..18d04d5 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/ReduceAllRoundProcessFunction.java
@@ -18,16 +18,26 @@
 
 package org.apache.flink.test.iteration.operators;
 
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.typeutils.MapTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.ProcessFunction;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.OutputTag;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.function.BiConsumer;
 
 /**
@@ -35,7 +45,7 @@ import java.util.function.BiConsumer;
  * the received numbers to the next operator.
  */
 public class ReduceAllRoundProcessFunction extends ProcessFunction<EpochRecord, EpochRecord>
-        implements IterationListener<EpochRecord> {
+        implements IterationListener<EpochRecord>, CheckpointedFunction {
 
     private final boolean sync;
 
@@ -47,17 +57,54 @@ public class ReduceAllRoundProcessFunction extends ProcessFunction<EpochRecord,
 
     private transient OutputTag<OutputRecord<Integer>> outputTag;
 
+    private transient ListState<Map<Integer, Integer>> sumByEpochsState;
+
+    private transient ListState<EpochRecord> cachedRecordsState;
+
     public ReduceAllRoundProcessFunction(boolean sync, int maxRound) {
         this.sync = sync;
         this.maxRound = maxRound;
     }
 
     @Override
-    public void open(Configuration parameters) throws Exception {
-        super.open(parameters);
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
         sumByEpochs = new HashMap<>();
         cachedRecords = new ArrayList<>();
-        outputTag = new OutputTag<OutputRecord<Integer>>("output") {};
+
+        sumByEpochsState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "test",
+                                        new MapTypeInfo<>(
+                                                BasicTypeInfo.INT_TYPE_INFO,
+                                                BasicTypeInfo.INT_TYPE_INFO)));
+        Optional<Map<Integer, Integer>> old =
+                OperatorStateUtils.getUniqueElement(sumByEpochsState, "test");
+        old.ifPresent(v -> sumByEpochs.putAll(v));
+
+        cachedRecordsState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("cache", EpochRecord.class));
+        cachedRecordsState.get().forEach(v -> cachedRecords.add(v));
+    }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        sumByEpochsState.clear();
+        sumByEpochsState.update(Collections.singletonList(new HashMap<>(sumByEpochs)));
+
+        cachedRecordsState.clear();
+        cachedRecordsState.addAll(cachedRecords);
+    }
+
+    @Override
+    public void open(Configuration parameters) throws Exception {
+        super.open(parameters);
+        this.outputTag = new OutputTag<OutputRecord<Integer>>("output") {};
     }
 
     @Override
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
index 4054cf6..ed564b0 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/SequenceSource.java
@@ -18,10 +18,19 @@
 
 package org.apache.flink.test.iteration.operators;
 
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 
-/** A source emitting the continuous int sequences. */
-public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
+import java.util.Collections;
+
+/** A source emitting a continuous int sequence. */
+public class SequenceSource extends RichParallelSourceFunction<EpochRecord>
+        implements CheckpointedFunction {
 
     private final int maxValue;
 
@@ -31,6 +40,10 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
 
     private volatile boolean canceled;
 
+    private int next;
+
+    private ListState<Integer> nextState;
+
     public SequenceSource(int maxValue, boolean holdAfterMaxValue, int period) {
         this.maxValue = maxValue;
         this.holdAfterMaxValue = holdAfterMaxValue;
@@ -38,9 +51,22 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
     }
 
     @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        nextState =
+                functionInitializationContext
+                        .getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("next", Integer.class));
+        next = OperatorStateUtils.getUniqueElement(nextState, "next").orElse(0);
+    }
+
+    @Override
     public void run(SourceContext<EpochRecord> ctx) throws Exception {
-        for (int i = 0; i < maxValue && !canceled; ++i) {
-            ctx.collect(new EpochRecord(0, i));
+        while (next < maxValue && !canceled) {
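+            // Emit the record and advance "next" under the checkpoint lock so that a
+            // snapshot never observes a half-updated position.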
+            synchronized (ctx.getCheckpointLock()) {
+                ctx.collect(new EpochRecord(0, next++));
+            }
+
             if (period > 0) {
                 Thread.sleep(period);
             }
@@ -57,4 +83,10 @@ public class SequenceSource extends RichParallelSourceFunction<EpochRecord> {
     public void cancel() {
         canceled = true;
     }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        nextState.clear();
+        nextState.update(Collections.singletonList(next));
+    }
 }
diff --git a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
index 35e6876..1fea9a5 100644
--- a/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
+++ b/flink-ml-tests/src/test/java/org/apache/flink/test/iteration/operators/TwoInputReduceAllRoundProcessFunction.java
@@ -20,6 +20,9 @@ package org.apache.flink.test.iteration.operators;
 
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
 import org.apache.flink.util.Collector;
 
@@ -28,7 +31,7 @@ import org.apache.flink.util.Collector;
  */
 public class TwoInputReduceAllRoundProcessFunction
         extends CoProcessFunction<EpochRecord, EpochRecord, EpochRecord>
-        implements IterationListener<EpochRecord> {
+        implements IterationListener<EpochRecord>, CheckpointedFunction {
 
     private final ReduceAllRoundProcessFunction reduceAllRoundProcessFunction;
 
@@ -78,4 +81,15 @@ public class TwoInputReduceAllRoundProcessFunction
         // Processing the first round of messages.
         reduceAllRoundProcessFunction.processRecord(record, ctx::output, out);
     }
+
+    @Override
+    public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
+        reduceAllRoundProcessFunction.snapshotState(functionSnapshotContext);
+    }
+
+    @Override
+    public void initializeState(FunctionInitializationContext functionInitializationContext)
+            throws Exception {
+        reduceAllRoundProcessFunction.initializeState(functionInitializationContext);
+    }
 }

[flink-ml] 06/08: [FLINK-24655][iteration] Skip the repeat round for all-round operator

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 31ffe6c80339163b4bad05a14a8703d20177edb4
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 01:28:27 2021 +0800

    [FLINK-24655][iteration] Skip the repeat round for all-round operator
---
 .../allround/AbstractAllRoundWrapperOperator.java  | 151 ++++++++++++++++++---
 .../OneInputAllRoundWrapperOperatorTest.java       |  70 ++++++++++
 2 files changed, 204 insertions(+), 17 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index 2855e38..d3461a1 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -18,28 +18,45 @@
 
 package org.apache.flink.iteration.operator.allround;
 
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.OperatorStateStore;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.AbstractWrapperOperator;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.metrics.groups.OperatorMetricGroup;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
+import org.apache.flink.streaming.api.operators.KeyContext;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
 import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
 import org.apache.flink.streaming.api.operators.TimestampedCollector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.OutputTag;
 
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
 import java.io.IOException;
+import java.util.Collections;
 
 import static org.apache.flink.iteration.operator.OperatorUtils.processOperatorOrUdfIfSatisfy;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** The base class for the all-round wrapper operators. */
 public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperator<T>>
@@ -49,6 +66,13 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     private final IterationContext iterationContext;
 
+    // --------------- state ---------------------------
+    private int latestEpochWatermark = -1;
+
+    private ListState<Integer> parallelismState;
+
+    private ListState<Integer> latestEpochWatermarkState;
+
     @SuppressWarnings({"unchecked", "rawtypes"})
     public AbstractAllRoundWrapperOperator(
             StreamOperatorParameters<IterationRecord<T>> parameters,
@@ -75,6 +99,11 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
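+        // Skip epoch watermarks that were already processed before a failover so the
+        // wrapped operator does not repeat the round.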
+        if (epochWatermark <= latestEpochWatermark) {
+            return;
+        }
+        latestEpochWatermark = epochWatermark;
+
         setIterationContextRound(epochWatermark);
         processOperatorOrUdfIfSatisfy(
                 wrappedOperator,
@@ -100,23 +129,42 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
     }
 
     @Override
-    public void open() throws Exception {
-        wrappedOperator.open();
-    }
+    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
+            throws Exception {
+        RecordingStreamTaskStateInitializer recordingStreamTaskStateInitializer =
+                new RecordingStreamTaskStateInitializer(streamTaskStateManager);
+        wrappedOperator.initializeState(recordingStreamTaskStateInitializer);
+        checkState(recordingStreamTaskStateInitializer.lastCreated != null);
 
-    @Override
-    public void finish() throws Exception {
-        wrappedOperator.finish();
-    }
+        OperatorStateStore operatorStateStore =
+                recordingStreamTaskStateInitializer.lastCreated.operatorStateBackend();
 
-    @Override
-    public void close() throws Exception {
-        wrappedOperator.close();
-    }
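+        // The parallelism is stored in union list state so that after a restore every
+        // subtask can verify that the parallelism has not changed.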
+        parallelismState =
+                operatorStateStore.getUnionListState(
+                        new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+                .ifPresent(
+                        oldParallelism ->
+                                checkState(
+                                        oldParallelism
+                                                == containingTask
+                                                        .getEnvironment()
+                                                        .getTaskInfo()
+                                                        .getNumberOfParallelSubtasks(),
+                                        "The all-round wrapper operator is recovered with parallelism changed from "
+                                                + oldParallelism
+                                                + " to "
+                                                + containingTask
+                                                        .getEnvironment()
+                                                        .getTaskInfo()
+                                                        .getNumberOfParallelSubtasks()));
 
-    @Override
-    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
-        wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
+        latestEpochWatermarkState =
+                operatorStateStore.getListState(
+                        new ListStateDescriptor<>("latestEpoch", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(latestEpochWatermarkState, "latestEpoch")
+                .ifPresent(
+                        oldLatestEpochWatermark -> latestEpochWatermark = oldLatestEpochWatermark);
     }
 
     @Override
@@ -126,14 +174,41 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
             CheckpointOptions checkpointOptions,
             CheckpointStreamFactory storageLocation)
             throws Exception {
+
+        // Always clear the union list state before setting the value.
+        parallelismState.clear();
+        if (containingTask.getEnvironment().getTaskInfo().getIndexOfThisSubtask() == 0) {
+            parallelismState.update(
+                    Collections.singletonList(
+                            containingTask
+                                    .getEnvironment()
+                                    .getTaskInfo()
+                                    .getNumberOfParallelSubtasks()));
+        }
+        latestEpochWatermarkState.update(Collections.singletonList(latestEpochWatermark));
+
         return wrappedOperator.snapshotState(
                 checkpointId, timestamp, checkpointOptions, storageLocation);
     }
 
     @Override
-    public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
-            throws Exception {
-        wrappedOperator.initializeState(streamTaskStateManager);
+    public void open() throws Exception {
+        wrappedOperator.open();
+    }
+
+    @Override
+    public void finish() throws Exception {
+        wrappedOperator.finish();
+    }
+
+    @Override
+    public void close() throws Exception {
+        wrappedOperator.close();
+    }
+
+    @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        wrappedOperator.prepareSnapshotPreBarrier(checkpointId);
     }
 
     @Override
@@ -176,6 +251,11 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
         return wrappedOperator.getCurrentKey();
     }
 
+    @VisibleForTesting
+    int getLatestEpochWatermark() {
+        return latestEpochWatermark;
+    }
+
     private class IterationContext implements IterationListener.Context {
 
         @Override
@@ -183,4 +263,41 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
             proxyOutput.collect(outputTag, new StreamRecord<>(value));
         }
     }
+
+    private static class RecordingStreamTaskStateInitializer implements StreamTaskStateInitializer {
+
+        private final StreamTaskStateInitializer wrapped;
+
+        StreamOperatorStateContext lastCreated;
+
+        public RecordingStreamTaskStateInitializer(StreamTaskStateInitializer wrapped) {
+            this.wrapped = wrapped;
+        }
+
+        @Override
+        public StreamOperatorStateContext streamOperatorStateContext(
+                @Nonnull OperatorID operatorID,
+                @Nonnull String s,
+                @Nonnull ProcessingTimeService processingTimeService,
+                @Nonnull KeyContext keyContext,
+                @Nullable TypeSerializer<?> typeSerializer,
+                @Nonnull CloseableRegistry closeableRegistry,
+                @Nonnull MetricGroup metricGroup,
+                double v,
+                boolean b)
+                throws Exception {
+            lastCreated =
+                    wrapped.streamOperatorStateContext(
+                            operatorID,
+                            s,
+                            processingTimeService,
+                            keyContext,
+                            typeSerializer,
+                            closeableRegistry,
+                            metricGroup,
+                            v,
+                            b);
+            return lastCreated;
+        }
+    }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index f628b65..9ebb975 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -21,12 +21,14 @@ package org.apache.flink.iteration.operator.allround;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.operator.OperatorWrapper;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
@@ -38,7 +40,9 @@ import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
@@ -53,6 +57,7 @@ import java.util.Arrays;
 import java.util.List;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 
 /** Tests the {@link OneInputAllRoundWrapperOperator}. */
 public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
@@ -129,6 +134,71 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
         }
     }
 
+    @Test
+    public void testSnapshotAndRestore() throws Exception {
+        StreamOperatorFactory<IterationRecord<Integer>> wrapperFactory =
+                new RecordingOperatorFactory<>(
+                        SimpleOperatorFactory.of(new LifeCycleTrackingOneInputStreamOperator()),
+                        new AllRoundOperatorWrapper<>());
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot = null;
+        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new,
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+                        .build()) {
+            harness.getTaskStateManager().getWaitForReportLatch().reset();
+            harness.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(5, "fake")));
+            harness.getStreamTask()
+                    .triggerCheckpointAsync(
+                            new CheckpointMetaData(2, 1000),
+                            CheckpointOptions.alignedNoTimeout(
+                                    CheckpointType.CHECKPOINT,
+                                    CheckpointStorageLocationReference.getDefault()));
+            harness.processAll();
+
+            harness.getTaskStateManager().getWaitForReportLatch().await();
+            taskStateSnapshot = harness.getTaskStateManager().getLastJobManagerTaskStateSnapshot();
+        }
+
+        assertNotNull(taskStateSnapshot);
+        try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
+                new StreamTaskMailboxTestHarnessBuilder<>(
+                                OneInputStreamTask::new,
+                                new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setTaskStateSnapshot(2, taskStateSnapshot)
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
+                        .setupOutputForSingletonOperatorChain(wrapperFactory, operatorId)
+                        .build()) {
+            assertEquals(
+                    5,
+                    ((AbstractAllRoundWrapperOperator) RecordingOperatorFactory.latest)
+                            .getLatestEpochWatermark());
+        }
+    }
+
+    private static class RecordingOperatorFactory<OUT> extends WrapperOperatorFactory<OUT> {
+
+        static StreamOperator<?> latest = null;
+
+        public RecordingOperatorFactory(
+                StreamOperatorFactory<OUT> operatorFactory,
+                OperatorWrapper<OUT, IterationRecord<OUT>> wrapper) {
+            super(operatorFactory, wrapper);
+        }
+
+        @Override
+        public <T extends StreamOperator<IterationRecord<OUT>>> T createStreamOperator(
+                StreamOperatorParameters<IterationRecord<OUT>> parameters) {
+            latest = super.createStreamOperator(parameters);
+            return (T) latest;
+        }
+    }
+
     private static class LifeCycleTrackingOneInputStreamOperator
             extends AbstractStreamOperator<Integer>
             implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {

[flink-ml] 01/08: [FLINK-24655][iteration] HeadOperator waits for MAX_WATERMARK to iterate back before terminating.

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 63a82c24d9aaae07b28d5043058199983da5cee4
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Wed Oct 6 16:19:21 2021 +0800

    [FLINK-24655][iteration] HeadOperator waits for MAX_WATERMARK to iterate back before terminating.
    
    This lays the groundwork for checkpointing: for checkpoints of
    a job with feedback edges, the in-flight feedback records must
    also be included in the snapshot. To make sure that all the
    checkpoints triggered before the terminating globally aligned
    event complete, the head operator has to wait for one more
    round before it terminates.
---
 .../flink/iteration/operator/HeadOperator.java     | 158 ++++++++++++---------
 .../headprocessor/HeadOperatorRecordProcessor.java |  58 ++++++++
 .../operator/headprocessor/HeadOperatorState.java  |  22 +++
 .../RegularHeadOperatorRecordProcessor.java        | 122 ++++++++++++++++
 .../TerminatingHeadOperatorRecordProcessor.java    |  60 ++++++++
 .../flink/iteration/operator/HeadOperatorTest.java |  23 ++-
 6 files changed, 369 insertions(+), 74 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index 0796897..d7a9a54 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.iteration.IterationID;
@@ -27,6 +28,9 @@ import org.apache.flink.iteration.broadcast.BroadcastOutput;
 import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
+import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
+import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
+import org.apache.flink.iteration.operator.headprocessor.TerminatingHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
@@ -45,17 +49,18 @@ import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
-import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.OutputTag;
 
-import java.util.HashMap;
-import java.util.Map;
+import java.io.IOException;
 import java.util.Objects;
 import java.util.concurrent.Executor;
 
+import static org.apache.flink.util.Preconditions.checkState;
+
 /**
- * The head operators unions the initialized variable stream and the feedback stream, and
- * synchronize the epoch watermark (round).
+ * The head operator unions the initialized variable stream and the feedback stream, and
+ * synchronizes the epoch watermark (round).
  */
 public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         implements OneInputStreamOperator<IterationRecord<?>, IterationRecord<?>>,
@@ -76,15 +81,15 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     private final MailboxExecutor mailboxExecutor;
 
-    private final Map<Integer, Long> numFeedbackRecordsPerEpoch;
+    private transient BroadcastOutput<?> eventBroadcastOutput;
 
-    private transient String uniqueSenderId;
+    private transient ContextImpl processorContext;
 
-    private transient BroadcastOutput<?> eventBroadcastOutput;
+    // ------------- runtime -------------------
 
-    private transient StreamRecord<IterationRecord<?>> reusable;
+    private HeadOperatorStatus status;
 
-    private transient boolean shouldTerminate;
+    private HeadOperatorRecordProcessor recordProcessor;
 
     public HeadOperator(
             IterationID iterationId,
@@ -98,7 +103,6 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         this.isCriteriaStream = isCriteriaStream;
         this.mailboxExecutor = Objects.requireNonNull(mailboxExecutor);
         this.operatorEventGateway = Objects.requireNonNull(operatorEventGateway);
-        this.numFeedbackRecordsPerEpoch = new HashMap<>();
 
         // Even though this operator does not use the processing
         // time service, AbstractStreamOperator requires this
@@ -112,9 +116,6 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
             StreamConfig config,
             Output<StreamRecord<IterationRecord<?>>> output) {
         super.setup(containingTask, config, output);
-        uniqueSenderId =
-                OperatorUtils.getUniqueSenderId(
-                        getOperatorID(), getRuntimeContext().getIndexOfThisSubtask());
         eventBroadcastOutput =
                 BroadcastOutputFactory.createBroadcastOutput(
                         output, metrics.getIOMetricGroup().getNumRecordsOutCounter());
@@ -124,12 +125,14 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     public void initializeState(StateInitializationContext context) throws Exception {
         super.initializeState(context);
 
-        reusable = new StreamRecord<>(null);
+        processorContext = new ContextImpl();
+        status = HeadOperatorStatus.RUNNING;
+        recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
 
         // Here we register a mail
         registerFeedbackConsumer(
                 (Runnable runnable) -> {
-                    if (!shouldTerminate) {
+                    if (status != HeadOperatorStatus.TERMINATED) {
                         mailboxExecutor.execute(runnable::run, "Head feedback");
                     }
                 });
@@ -137,73 +140,39 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
-        processRecord(element);
+        recordProcessor.processElement(element);
     }
 
     @Override
     public void processFeedback(StreamRecord<IterationRecord<?>> iterationRecord) throws Exception {
-        if (iterationRecord.getValue().getType() == IterationRecord.Type.RECORD) {
-            numFeedbackRecordsPerEpoch.compute(
-                    iterationRecord.getValue().getEpoch(),
-                    (round, count) -> count == null ? 1 : count + 1);
-        }
-        processRecord(iterationRecord);
-    }
-
-    private void processRecord(StreamRecord<IterationRecord<?>> iterationRecord) {
-        switch (iterationRecord.getValue().getType()) {
-            case RECORD:
-                reusable.replace(iterationRecord.getValue(), iterationRecord.getTimestamp());
-                output.collect(reusable);
-                break;
-            case EPOCH_WATERMARK:
-                LOG.debug(
-                        "Head Received epoch watermark {}", iterationRecord.getValue().getEpoch());
-                sendEpochWatermarkToCoordinator(iterationRecord.getValue().getEpoch());
-                break;
+        boolean terminated = recordProcessor.processFeedbackElement(iterationRecord);
+        if (terminated) {
+            checkState(status == HeadOperatorStatus.TERMINATING);
+            status = HeadOperatorStatus.TERMINATED;
         }
     }
 
     @Override
-    @SuppressWarnings({"unchecked", "rawtypes"})
     public void handleOperatorEvent(OperatorEvent operatorEvent) {
         if (operatorEvent instanceof GloballyAlignedEvent) {
-            try {
-                GloballyAlignedEvent globallyAlignedEvent = (GloballyAlignedEvent) operatorEvent;
-                LOG.info("Received global event {}", globallyAlignedEvent);
-
-                shouldTerminate = globallyAlignedEvent.isTerminated();
-                reusable.replace(
-                        IterationRecord.newEpochWatermark(
-                                globallyAlignedEvent.isTerminated()
-                                        ? Integer.MAX_VALUE
-                                        : globallyAlignedEvent.getEpoch(),
-                                uniqueSenderId),
-                        0);
-                eventBroadcastOutput.broadcastEmit((StreamRecord) reusable);
-                numFeedbackRecordsPerEpoch.remove(globallyAlignedEvent.getEpoch());
-            } catch (Exception e) {
-                ExceptionUtils.rethrow(e);
+            boolean shouldTerminate =
+                    recordProcessor.onGloballyAligned((GloballyAlignedEvent) operatorEvent);
+            if (shouldTerminate) {
+                status = HeadOperatorStatus.TERMINATING;
+                recordProcessor = new TerminatingHeadOperatorRecordProcessor();
             }
         }
     }
 
     @Override
     public void endInput() throws Exception {
-        sendEpochWatermarkToCoordinator(0);
-        while (!shouldTerminate) {
+        recordProcessor.processElement(
+                new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+        while (status != HeadOperatorStatus.TERMINATED) {
             mailboxExecutor.yield();
         }
     }
 
-    private void sendEpochWatermarkToCoordinator(int round) {
-        operatorEventGateway.sendEventToCoordinator(
-                new SubtaskAlignedEvent(
-                        round,
-                        numFeedbackRecordsPerEpoch.getOrDefault(round, 0L),
-                        isCriteriaStream));
-    }
-
     private void registerFeedbackConsumer(Executor mailboxExecutor) {
         int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
         int attemptNum = getRuntimeContext().getAttemptNumber();
@@ -217,11 +186,6 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     }
 
     @VisibleForTesting
-    Map<Integer, Long> getNumFeedbackRecordsPerEpoch() {
-        return numFeedbackRecordsPerEpoch;
-    }
-
-    @VisibleForTesting
     public OperatorEventGateway getOperatorEventGateway() {
         return operatorEventGateway;
     }
@@ -230,4 +194,62 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     MailboxExecutor getMailboxExecutor() {
         return mailboxExecutor;
     }
+
+    @VisibleForTesting
+    HeadOperatorRecordProcessor getRecordProcessor() {
+        return recordProcessor;
+    }
+
+    @VisibleForTesting
+    public HeadOperatorStatus getStatus() {
+        return status;
+    }
+
+    @VisibleForTesting
+    enum HeadOperatorStatus {
+        RUNNING,
+
+        TERMINATING,
+
+        TERMINATED
+    }
+
+    private class ContextImpl implements HeadOperatorRecordProcessor.Context {
+
+        @Override
+        public StreamConfig getStreamConfig() {
+            return HeadOperator.this.config;
+        }
+
+        @Override
+        public TaskInfo getTaskInfo() {
+            return getContainingTask().getEnvironment().getTaskInfo();
+        }
+
+        @Override
+        public void output(StreamRecord<IterationRecord<?>> record) {
+            output.collect(record);
+        }
+
+        @Override
+        public void output(
+                OutputTag<IterationRecord<?>> outputTag, StreamRecord<IterationRecord<?>> record) {
+            output.collect(outputTag, record);
+        }
+
+        @Override
+        public void broadcastOutput(StreamRecord<IterationRecord<?>> record) {
+            try {
+                eventBroadcastOutput.broadcastEmit((StreamRecord) record);
+            } catch (IOException e) {
+                throw new FlinkRuntimeException("Failed to broadcast event", e);
+            }
+        }
+
+        @Override
+        public void updateEpochToCoordinator(int epoch, long numFeedbackRecords) {
+            operatorEventGateway.sendEventToCoordinator(
+                    new SubtaskAlignedEvent(epoch, numFeedbackRecords, isCriteriaStream));
+        }
+    }
 }
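
The change above boils the termination protocol down to a small state machine:
the operator keeps RUNNING until the coordinator announces global termination,
switches to the TERMINATING processor, and only becomes TERMINATED once the
terminal epoch watermark has iterated back through the feedback edge. A
minimal, self-contained sketch of that hand-off (the names mirror the diff,
but the driver class itself is illustrative only, not part of the commit):

    // Illustrative sketch of the HeadOperator termination hand-off.
    class TerminationFlow {
        enum Status { RUNNING, TERMINATING, TERMINATED }

        private Status status = Status.RUNNING;

        // A terminating GloballyAlignedEvent arrives from the coordinator.
        void onGloballyAligned(boolean terminated) {
            if (terminated) {
                status = Status.TERMINATING; // swap in the terminating processor
            }
        }

        // The terminal epoch watermark iterates back through the feedback edge.
        void onFeedbackEpochWatermark(int epoch) {
            if (status == Status.TERMINATING && epoch == Integer.MAX_VALUE + 1) {
                status = Status.TERMINATED; // endInput() may now return
            }
        }
    }
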
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
new file mode 100644
index 0000000..78132c3
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.headprocessor;
+
+import org.apache.flink.api.common.TaskInfo;
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.HeadOperator;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.OutputTag;
+
+/** The component that actually handles the records and events received in the {@link HeadOperator}. */
+public interface HeadOperatorRecordProcessor {
+
+    void initializeState(HeadOperatorState headOperatorState) throws Exception;
+
+    void processElement(StreamRecord<IterationRecord<?>> record);
+
+    boolean processFeedbackElement(StreamRecord<IterationRecord<?>> record);
+
+    boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent);
+
+    HeadOperatorState snapshotState();
+
+    /** The context for {@link HeadOperatorRecordProcessor}. */
+    interface Context {
+
+        StreamConfig getStreamConfig();
+
+        TaskInfo getTaskInfo();
+
+        void output(StreamRecord<IterationRecord<?>> record);
+
+        void output(
+                OutputTag<IterationRecord<?>> outputTag, StreamRecord<IterationRecord<?>> record);
+
+        void broadcastOutput(StreamRecord<IterationRecord<?>> record);
+
+        void updateEpochToCoordinator(int epoch, long numFeedbackRecords);
+    }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
new file mode 100644
index 0000000..861f481
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.headprocessor;
+
+/** The state entry for the head operator. */
+public class HeadOperatorState {}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
new file mode 100644
index 0000000..f1a6b0f
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.headprocessor;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.OperatorUtils;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Processes the records and events before the terminating globally aligned event arrives from the coordinator.
+ */
+public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordProcessor {
+
+    protected static final Logger LOG =
+            LoggerFactory.getLogger(RegularHeadOperatorRecordProcessor.class);
+
+    private final Context headOperatorContext;
+
+    private final StreamRecord<IterationRecord<?>> reusable;
+
+    private final Map<Integer, Long> numFeedbackRecordsPerEpoch;
+
+    private final String senderId;
+
+    public RegularHeadOperatorRecordProcessor(Context headOperatorContext) {
+        this.headOperatorContext = headOperatorContext;
+
+        this.reusable = new StreamRecord<>(null);
+        this.numFeedbackRecordsPerEpoch = new HashMap<>();
+
+        this.senderId =
+                OperatorUtils.getUniqueSenderId(
+                        headOperatorContext.getStreamConfig().getOperatorID(),
+                        headOperatorContext.getTaskInfo().getIndexOfThisSubtask());
+    }
+
+    @Override
+    public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+
+    @Override
+    public void processElement(StreamRecord<IterationRecord<?>> element) {
+        processRecord(element);
+    }
+
+    @Override
+    public boolean processFeedbackElement(StreamRecord<IterationRecord<?>> element) {
+        if (element.getValue().getType() == IterationRecord.Type.RECORD) {
+            numFeedbackRecordsPerEpoch.compute(
+                    element.getValue().getEpoch(), (epoch, count) -> count == null ? 1 : count + 1);
+        }
+
+        processRecord(element);
+
+        return false;
+    }
+
+    @Override
+    public boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent) {
+        LOG.info("Received global event {}", globallyAlignedEvent);
+
+        reusable.replace(
+                IterationRecord.newEpochWatermark(
+                        globallyAlignedEvent.isTerminated()
+                                ? Integer.MAX_VALUE
+                                : globallyAlignedEvent.getEpoch(),
+                        senderId),
+                0);
+        headOperatorContext.broadcastOutput(reusable);
+
+        return globallyAlignedEvent.isTerminated();
+    }
+
+    @Override
+    public HeadOperatorState snapshotState() {
+        return new HeadOperatorState();
+    }
+
+    @VisibleForTesting
+    public Map<Integer, Long> getNumFeedbackRecordsPerEpoch() {
+        return numFeedbackRecordsPerEpoch;
+    }
+
+    private void processRecord(StreamRecord<IterationRecord<?>> iterationRecord) {
+        switch (iterationRecord.getValue().getType()) {
+            case RECORD:
+                reusable.replace(iterationRecord.getValue(), iterationRecord.getTimestamp());
+                headOperatorContext.output(reusable);
+                break;
+            case EPOCH_WATERMARK:
+                LOG.info("Head received epoch watermark {}", iterationRecord.getValue().getEpoch());
+                headOperatorContext.updateEpochToCoordinator(
+                        iterationRecord.getValue().getEpoch(),
+                        numFeedbackRecordsPerEpoch.getOrDefault(
+                                iterationRecord.getValue().getEpoch(), 0L));
+                break;
+        }
+    }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
new file mode 100644
index 0000000..c5377e5
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.operator.headprocessor;
+
+import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.FlinkRuntimeException;
+
+/**
+ * Processor used after the terminating globally aligned event has been received from the
+ * coordinator, but before the (Integer.MAX_VALUE + 1) epoch watermark iterates back via feedback.
+ */
+public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecordProcessor {
+
+    @Override
+    public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+
+    @Override
+    public void processElement(StreamRecord<IterationRecord<?>> record) {
+        throw new FlinkRuntimeException(
+                "It is not possible to receive an element from the normal input while terminating.");
+    }
+
+    @Override
+    public boolean processFeedbackElement(StreamRecord<IterationRecord<?>> record) {
+        if (record.getValue().getType() == IterationRecord.Type.EPOCH_WATERMARK) {
+            return record.getValue().getEpoch() == Integer.MAX_VALUE + 1;
+        }
+
+        return false;
+    }
+
+    @Override
+    public boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent) {
+        throw new FlinkRuntimeException(
+                "It is not possible to receive another globally aligned event while terminating.");
+    }
+
+    @Override
+    public HeadOperatorState snapshotState() {
+        return new HeadOperatorState();
+    }
+}
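
Note that Integer.MAX_VALUE + 1 deliberately wraps around to Integer.MIN_VALUE
in Java int arithmetic; the tail side computes the epoch of the terminal
watermark the same way before feeding it back (see the HeadOperatorTest change
below), so both ends agree on the wrapped sentinel value:

    // Java int arithmetic wraps: this is the sentinel both sides compute.
    int terminalEpoch = Integer.MAX_VALUE + 1; // == Integer.MIN_VALUE
    assert terminalEpoch == Integer.MIN_VALUE;
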
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index cc3ce31..f54422e 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
+import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -86,12 +87,12 @@ public class HeadOperatorTest extends TestLogger {
                             new StreamRecord<>(IterationRecord.newRecord(2, 0), 3),
                             new StreamRecord<>(IterationRecord.newRecord(4, 1), 4));
             assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
-            assertEquals(
-                    2,
-                    (long)
-                            RecordingHeadOperatorFactory.latestHeadOperator
-                                    .getNumFeedbackRecordsPerEpoch()
-                                    .get(1));
+
+            RegularHeadOperatorRecordProcessor recordProcessor =
+                    (RegularHeadOperatorRecordProcessor)
+                            RecordingHeadOperatorFactory.latestHeadOperator.getRecordProcessor();
+
+            assertEquals(2, (long) recordProcessor.getNumFeedbackRecordsPerEpoch().get(1));
         }
     }
 
@@ -153,6 +154,16 @@ public class HeadOperatorTest extends TestLogger {
                                                     new SerializedValue<>(
                                                             new GloballyAlignedEvent(1, true)));
 
+                                    while (RecordingHeadOperatorFactory.latestHeadOperator
+                                                    .getStatus()
+                                            == HeadOperator.HeadOperatorStatus.RUNNING) {}
+                                    putFeedbackRecords(
+                                            iterationId,
+                                            0,
+                                            new StreamRecord<>(
+                                                    IterationRecord.newEpochWatermark(
+                                                            Integer.MAX_VALUE + 1, "tail")));
+
                                     return null;
                                 } catch (Throwable e) {
                                     RecordingHeadOperatorFactory.latestHeadOperator

[flink-ml] 04/08: [FLINK-24655][iteration] Support snapshotting the feedback records on checkpoint

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 7c99864cdc85d61c08a3668ddeea07cdd3aa9bd4
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Mon Oct 4 21:21:51 2021 +0800

    [FLINK-24655][iteration] Support snapshotting the feedback records on checkpoint
---
 .../flink/iteration/checkpoint/Checkpoints.java    | 130 ++++++++++++
 .../datacache/nonkeyed/DataCacheSnapshot.java      | 224 +++++++++++++++++++++
 .../datacache/nonkeyed/DataCacheWriter.java        |  20 +-
 .../iteration/datacache/nonkeyed/Segment.java      |   5 +
 .../flink/iteration/operator/OperatorUtils.java    |   5 +-
 .../datacache/nonkeyed/DataCacheSnapshotTest.java  | 213 ++++++++++++++++++++
 6 files changed, 592 insertions(+), 5 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
new file mode 100644
index 0000000..03420f8
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
@@ -0,0 +1,130 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.checkpoint;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
+import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.util.ResourceGuard;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+/** Maintains the pending checkpoints. */
+public class Checkpoints<T> implements AutoCloseable {
+
+    private static final Logger LOG = LoggerFactory.getLogger(Checkpoints.class);
+
+    private final TypeSerializer<T> typeSerializer;
+    private final FileSystem fileSystem;
+    private final SupplierWithException<Path, IOException> pathSupplier;
+
+    private final TreeMap<Long, PendingCheckpoint> uncompletedCheckpoints = new TreeMap<>();
+
+    public Checkpoints(
+            TypeSerializer<T> typeSerializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathSupplier) {
+        this.typeSerializer = typeSerializer;
+        this.fileSystem = fileSystem;
+        this.pathSupplier = pathSupplier;
+    }
+
+    public void startLogging(long checkpointId, OperatorStateCheckpointOutputStream outputStream)
+            throws IOException {
+        DataCacheWriter<T> dataCacheWriter =
+                new DataCacheWriter<>(typeSerializer, fileSystem, pathSupplier);
+        ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
+        uncompletedCheckpoints.put(
+                checkpointId, new PendingCheckpoint(dataCacheWriter, outputStream, snapshotLease));
+    }
+
+    public void append(T element) throws IOException {
+        for (PendingCheckpoint pendingCheckpoint : uncompletedCheckpoints.values()) {
+            pendingCheckpoint.dataCacheWriter.addRecord(element);
+        }
+    }
+
+    public void commitCheckpointsUntil(long checkpointId) {
+        SortedMap<Long, PendingCheckpoint> completedCheckpoints =
+                uncompletedCheckpoints.headMap(checkpointId, true);
+        completedCheckpoints
+                .values()
+                .forEach(
+                        pendingCheckpoint -> {
+                            try {
+                                pendingCheckpoint.dataCacheWriter.finish();
+                                DataCacheSnapshot snapshot =
+                                        new DataCacheSnapshot(
+                                                fileSystem,
+                                                null,
+                                                pendingCheckpoint.dataCacheWriter
+                                                        .getFinishSegments());
+                                pendingCheckpoint.checkpointOutputStream.startNewPartition();
+                                snapshot.writeTo(pendingCheckpoint.checkpointOutputStream);
+                                pendingCheckpoint.dataCacheWriter.cleanup();
+                            } catch (Exception e) {
+                                LOG.error("Failed to commit checkpoint until " + checkpointId, e);
+                            } finally {
+                                pendingCheckpoint.snapshotLease.close();
+                            }
+                        });
+
+        completedCheckpoints.clear();
+    }
+
+    @Override
+    public void close() {
+        uncompletedCheckpoints.forEach(
+                (checkpointId, pendingCheckpoint) -> {
+                    pendingCheckpoint.snapshotLease.close();
+                    try {
+                        pendingCheckpoint.dataCacheWriter.cleanup();
+                    } catch (IOException e) {
+                        LOG.error("Failed to cleanup " + checkpointId, e);
+                    }
+                });
+        uncompletedCheckpoints.clear();
+    }
+
+    private class PendingCheckpoint {
+        final DataCacheWriter<T> dataCacheWriter;
+
+        final OperatorStateCheckpointOutputStream checkpointOutputStream;
+
+        final ResourceGuard.Lease snapshotLease;
+
+        public PendingCheckpoint(
+                DataCacheWriter<T> dataCacheWriter,
+                OperatorStateCheckpointOutputStream checkpointOutputStream,
+                ResourceGuard.Lease snapshotLease) {
+            this.dataCacheWriter = dataCacheWriter;
+            this.checkpointOutputStream = checkpointOutputStream;
+            this.snapshotLease = snapshotLease;
+        }
+    }
+}
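
As a rough usage sketch, the head operator is expected to drive the class
above through the checkpoint life cycle along the following lines (here
fileSystem, pathSupplier and outputStream stand in for the objects obtained
from the runtime and the snapshot context; error handling omitted):

    // Hypothetical driver code, not part of the commit.
    Checkpoints<Integer> checkpoints =
            new Checkpoints<>(IntSerializer.INSTANCE, fileSystem, pathSupplier);

    checkpoints.startLogging(1L, outputStream); // checkpoint 1 starts: buffer feedback
    checkpoints.append(42);                     // log each feedback record into every
                                                // pending checkpoint
    checkpoints.commitCheckpointsUntil(1L);     // feedback drained: seal checkpoint 1
    checkpoints.close();
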
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java
new file mode 100644
index 0000000..f837bd6
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshot.java
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.fs.FSDataOutputStream;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.runtime.util.NonClosingInputStreamDecorator;
+import org.apache.flink.runtime.util.NonClosingOutpusStreamDecorator;
+import org.apache.flink.statefun.flink.core.feedback.FeedbackConsumer;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.function.SupplierWithException;
+
+import org.apache.commons.io.input.BoundedInputStream;
+
+import javax.annotation.Nullable;
+
+import java.io.DataInputStream;
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** The snapshot of a data cache. It could be written to or recovered from an external stream. */
+public class DataCacheSnapshot {
+
+    private static final int CURRENT_VERSION = 1;
+
+    private final FileSystem fileSystem;
+
+    @Nullable private final Tuple2<Integer, Integer> readerPosition;
+
+    private final List<Segment> segments;
+
+    public DataCacheSnapshot(
+            FileSystem fileSystem,
+            @Nullable Tuple2<Integer, Integer> readerPosition,
+            List<Segment> segments) {
+        this.fileSystem = fileSystem;
+        this.readerPosition = readerPosition;
+        this.segments = segments;
+    }
+
+    public FileSystem getFileSystem() {
+        return fileSystem;
+    }
+
+    @Nullable
+    public Tuple2<Integer, Integer> getReaderPosition() {
+        return readerPosition;
+    }
+
+    public List<Segment> getSegments() {
+        return segments;
+    }
+
+    public void writeTo(OutputStream checkpointOutputStream) throws IOException {
+        try (DataOutputStream dos =
+                new DataOutputStream(new NonClosingOutpusStreamDecorator(checkpointOutputStream))) {
+            dos.writeInt(CURRENT_VERSION);
+            dos.writeBoolean(readerPosition != null);
+            if (readerPosition != null) {
+                dos.writeInt(readerPosition.f0);
+                dos.writeInt(readerPosition.f1);
+            }
+
+            dos.writeBoolean(fileSystem.isDistributedFS());
+            if (fileSystem.isDistributedFS()) {
+                // We only need to record the segments itself
+                serializeSegments(segments, dos);
+            } else {
+                // We have to copy the whole streams.
+                int totalRecords = segments.stream().mapToInt(Segment::getCount).sum();
+                long totalSize = segments.stream().mapToLong(Segment::getSize).sum();
+                checkState(totalRecords >= 0, "overflowed: " + totalRecords);
+                dos.writeInt(totalRecords);
+                dos.writeLong(totalSize);
+
+                for (Segment segment : segments) {
+                    try (FSDataInputStream inputStream = fileSystem.open(segment.getPath())) {
+                        IOUtils.copyBytes(inputStream, checkpointOutputStream, false);
+                    }
+                }
+            }
+        }
+    }
+
+    public static <T> void replay(
+            InputStream checkpointInputStream,
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            FeedbackConsumer<T> feedbackConsumer)
+            throws Exception {
+        try (DataInputStream dis =
+                new DataInputStream(new NonClosingInputStreamDecorator(checkpointInputStream))) {
+            int version = dis.readInt();
+            checkState(
+                    version == CURRENT_VERSION,
+                    "Currently only supports version " + CURRENT_VERSION);
+            parseReaderPosition(dis);
+
+            boolean isDistributedFS = dis.readBoolean();
+            if (isDistributedFS) {
+                List<Segment> segments = deserializeSegments(dis);
+                DataCacheReader<T> dataCacheReader =
+                        new DataCacheReader<T>(serializer, fileSystem, segments);
+                while (dataCacheReader.hasNext()) {
+                    feedbackConsumer.processFeedback(dataCacheReader.next());
+                }
+            } else {
+                DataInputViewStreamWrapper dataInputView = new DataInputViewStreamWrapper(dis);
+                int totalRecords = dis.readInt();
+                // Ignore the total size.
+                dis.readLong();
+                for (int i = 0; i < totalRecords; ++i) {
+                    feedbackConsumer.processFeedback(serializer.deserialize(dataInputView));
+                }
+            }
+        }
+    }
+
+    public static DataCacheSnapshot recover(
+            InputStream checkpointInputStream,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
+            throws IOException {
+        try (DataInputStream dis =
+                new DataInputStream(new NonClosingInputStreamDecorator(checkpointInputStream))) {
+            int version = dis.readInt();
+            checkState(
+                    version == CURRENT_VERSION,
+                    "Currently only supports version " + CURRENT_VERSION);
+            Tuple2<Integer, Integer> readerPosition = parseReaderPosition(dis);
+
+            boolean isDistributedFS = dis.readBoolean();
+            checkState(
+                    isDistributedFS == fileSystem.isDistributedFS(),
+                    "Currently we do not support changing the cache file system. "
+                            + "If required, please manually copy the directory from one filesystem to another.");
+
+            List<Segment> segments;
+            if (isDistributedFS) {
+                segments = deserializeSegments(dis);
+            } else {
+                int totalRecords = dis.readInt();
+                long totalSize = dis.readLong();
+
+                Path path = pathGenerator.get();
+                try (FSDataOutputStream outputStream =
+                        fileSystem.create(path, FileSystem.WriteMode.NO_OVERWRITE)) {
+
+                    BoundedInputStream inputStream =
+                            new BoundedInputStream(checkpointInputStream, totalSize);
+                    inputStream.setPropagateClose(false);
+                    IOUtils.copyBytes(inputStream, outputStream, false);
+                    inputStream.close();
+                }
+                segments = Collections.singletonList(new Segment(path, totalRecords, totalSize));
+            }
+
+            return new DataCacheSnapshot(fileSystem, readerPosition, segments);
+        }
+    }
+
+    private static Tuple2<Integer, Integer> parseReaderPosition(DataInputStream dataInputStream)
+            throws IOException {
+        Tuple2<Integer, Integer> readerPosition = null;
+        boolean hasReaderPosition = dataInputStream.readBoolean();
+        if (hasReaderPosition) {
+            readerPosition = new Tuple2<>(dataInputStream.readInt(), dataInputStream.readInt());
+        }
+
+        return readerPosition;
+    }
+
+    private static void serializeSegments(List<Segment> segments, DataOutputStream dataOutputStream)
+            throws IOException {
+        dataOutputStream.writeInt(segments.size());
+        for (int i = 0; i < segments.size(); ++i) {
+            dataOutputStream.writeUTF(segments.get(i).getPath().toString());
+            dataOutputStream.writeInt(segments.get(i).getCount());
+            dataOutputStream.writeLong(segments.get(i).getSize());
+        }
+    }
+
+    private static List<Segment> deserializeSegments(DataInputStream dataInputStream)
+            throws IOException {
+        List<Segment> segments = new ArrayList<>();
+        int numberOfSegments = dataInputStream.readInt();
+        for (int i = 0; i < numberOfSegments; ++i) {
+            segments.add(
+                    new Segment(
+                            new Path(dataInputStream.readUTF()),
+                            dataInputStream.readInt(),
+                            dataInputStream.readLong()));
+        }
+        return segments;
+    }
+}
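
The DataCacheSnapshotTest added at the end of this commit exercises the full
cycle; in miniature, writing a snapshot out and replaying it into a consumer
looks like the following (writer is assumed to be a finished
DataCacheWriter<Integer> on the local file system):

    // Write the snapshot to an in-memory stream, then replay every record.
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    new DataCacheSnapshot(FileSystem.getLocalFileSystem(), null, writer.getFinishSegments())
            .writeTo(bos);

    DataCacheSnapshot.replay(
            new ByteArrayInputStream(bos.toByteArray()),
            IntSerializer.INSTANCE,
            FileSystem.getLocalFileSystem(),
            record -> System.out.println("replayed " + record)); // FeedbackConsumer<Integer>
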
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
index 7a124b7..35256eb 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheWriter.java
@@ -24,12 +24,12 @@ import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.util.function.SupplierWithException;
 
 import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Optional;
-import java.util.function.Supplier;
 
 /** Records the data received and replays them when required. */
 public class DataCacheWriter<T> {
@@ -38,14 +38,16 @@ public class DataCacheWriter<T> {
 
     private final FileSystem fileSystem;
 
-    private final Supplier<Path> pathGenerator;
+    private final SupplierWithException<Path, IOException> pathGenerator;
 
     private final List<Segment> finishSegments;
 
     private SegmentWriter currentSegment;
 
     public DataCacheWriter(
-            TypeSerializer<T> serializer, FileSystem fileSystem, Supplier<Path> pathGenerator)
+            TypeSerializer<T> serializer,
+            FileSystem fileSystem,
+            SupplierWithException<Path, IOException> pathGenerator)
             throws IOException {
         this.serializer = serializer;
         this.fileSystem = fileSystem;
@@ -69,6 +71,10 @@ public class DataCacheWriter<T> {
         return finishSegments;
     }
 
+    public FileSystem getFileSystem() {
+        return fileSystem;
+    }
+
     public List<Segment> getFinishSegments() {
         return finishSegments;
     }
@@ -76,6 +82,7 @@ public class DataCacheWriter<T> {
     private void finishCurrentSegment(boolean newSegment) throws IOException {
         if (currentSegment != null) {
             currentSegment.finish().ifPresent(finishSegments::add);
+            currentSegment = null;
         }
 
         if (newSegment) {
@@ -118,4 +125,11 @@ public class DataCacheWriter<T> {
             }
         }
     }
+
+    public void cleanup() throws IOException {
+        finishCurrentSegment();
+        for (Segment segment : finishSegments) {
+            fileSystem.delete(segment.getPath(), false);
+        }
+    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java
index 9b3c1e9..5154f6f 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/datacache/nonkeyed/Segment.java
@@ -70,4 +70,9 @@ public class Segment implements Serializable {
     public int hashCode() {
         return Objects.hash(path, count, size);
     }
+
+    @Override
+    public String toString() {
+        return "Segment{" + "path=" + path + ", count=" + count + ", size=" + size + '}';
+    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
index 98121c1..292a90e 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorUtils.java
@@ -30,13 +30,14 @@ import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
 import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.util.ExceptionUtils;
+import org.apache.flink.util.function.SupplierWithException;
 import org.apache.flink.util.function.ThrowingConsumer;
 
+import java.io.IOException;
 import java.util.Arrays;
 import java.util.Random;
 import java.util.UUID;
 import java.util.concurrent.Executor;
-import java.util.function.Supplier;
 
 /** Utility class for operators. */
 public class OperatorUtils {
@@ -92,7 +93,7 @@ public class OperatorUtils {
         return new Path(pathStr);
     }
 
-    public static Supplier<Path> createDataCacheFileGenerator(
+    public static SupplierWithException<Path, IOException> createDataCacheFileGenerator(
             Path basePath, String fileTypeName, OperatorID operatorId) {
         return () ->
                 new Path(
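
Switching the generator from Supplier<Path> to SupplierWithException<Path,
IOException> lets path creation surface I/O failures directly instead of
wrapping them in unchecked exceptions; for example (basePath is a placeholder):

    // The lambda may now throw IOException without a try/catch wrapper.
    SupplierWithException<Path, IOException> pathGenerator =
            () -> new Path(basePath, "data-cache." + UUID.randomUUID());
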
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshotTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshotTest.java
new file mode 100644
index 0000000..615bdf5
--- /dev/null
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/datacache/nonkeyed/DataCacheSnapshotTest.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.datacache.nonkeyed;
+
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.fs.hdfs.HadoopFileSystem;
+import org.apache.flink.util.OperatingSystem;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hdfs.MiniDFSCluster;
+import org.junit.AfterClass;
+import org.junit.Assume;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.UnsupportedEncodingException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests the behavior of the {@link DataCacheSnapshot}. */
+@RunWith(Parameterized.class)
+public class DataCacheSnapshotTest extends TestLogger {
+
+    @ClassRule public static final TemporaryFolder CLASS_TEMPORARY_FOLDER = new TemporaryFolder();
+
+    private static MiniDFSCluster hdfsCluster;
+
+    private final FileSystem fileSystem;
+
+    private final Path basePath;
+
+    @BeforeClass
+    public static void createHDFS() throws Exception {
+        Assume.assumeTrue(!OperatingSystem.isWindows());
+
+        Configuration hdfsConfig = new Configuration();
+        hdfsConfig.set(
+                MiniDFSCluster.HDFS_MINIDFS_BASEDIR,
+                CLASS_TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        hdfsCluster = new MiniDFSCluster.Builder(hdfsConfig).build();
+    }
+
+    @AfterClass
+    public static void destroyHDFS() {
+        if (hdfsCluster != null) {
+            hdfsCluster.shutdown();
+        }
+
+        hdfsCluster = null;
+    }
+
+    @Parameterized.Parameters(name = "{0}")
+    public static Object[][] testData() throws IOException {
+        return new Object[][] {new Object[] {"local"}, new Object[] {"hdfs"}};
+    }
+
+    public DataCacheSnapshotTest(String fileSystemType) throws IOException {
+        if (fileSystemType.equals("local")) {
+            fileSystem = FileSystem.getLocalFileSystem();
+            basePath = new Path("file://" + CLASS_TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        } else if (fileSystemType.equals("hdfs")) {
+            fileSystem = new HadoopFileSystem(hdfsCluster.getNewFileSystemInstance(0));
+            basePath =
+                    new Path(hdfsCluster.getURI().toString() + "/" + UUID.randomUUID().toString());
+        } else {
+            throw new UnsupportedEncodingException("Unsupported fs type: " + fileSystemType);
+        }
+    }
+
+    @Test
+    public void testWithoutReaderPosition() throws Exception {
+        int[] numRecordsPerSegment = {100, 200, 300};
+        DataCacheWriter<Integer> writer = createWriterAndAddRecords(numRecordsPerSegment);
+        DataCacheSnapshot dataCacheSnapshot =
+                new DataCacheSnapshot(fileSystem, null, writer.getFinishSegments());
+        checkWriteAndRecoverAndReplay(numRecordsPerSegment, dataCacheSnapshot);
+    }
+
+    @Test
+    public void testWithReadPosition() throws Exception {
+        int[] numRecordsPerSegment = {100, 200, 300};
+        DataCacheWriter<Integer> writer = createWriterAndAddRecords(numRecordsPerSegment);
+        DataCacheSnapshot dataCacheSnapshot =
+                new DataCacheSnapshot(fileSystem, new Tuple2<>(0, 50), writer.getFinishSegments());
+        checkWriteAndRecoverAndReplay(numRecordsPerSegment, dataCacheSnapshot);
+    }
+
+    @Test
+    public void testSnapshotMultipleWritersIntoSingleStream() throws Exception {
+        int[] numRecordsPerSegment = {100, 200, 300};
+        DataCacheWriter<Integer> writer1 = createWriterAndAddRecords(numRecordsPerSegment);
+        DataCacheWriter<Integer> writer2 = createWriterAndAddRecords(numRecordsPerSegment);
+
+        checkWriteAndRecoverAndReplay(
+                numRecordsPerSegment,
+                new DataCacheSnapshot(fileSystem, null, writer1.getFinishSegments()),
+                new DataCacheSnapshot(fileSystem, null, writer2.getFinishSegments()));
+    }
+
+    private DataCacheWriter<Integer> createWriterAndAddRecords(int[] numRecordsPerSegment)
+            throws IOException {
+        DataCacheWriter<Integer> writer =
+                new DataCacheWriter<>(
+                        IntSerializer.INSTANCE,
+                        fileSystem,
+                        () -> new Path(basePath, "writer." + UUID.randomUUID().toString()));
+        int nextNumber = 0;
+        for (int numRecord : numRecordsPerSegment) {
+            for (int i = 0; i < numRecord; ++i) {
+                writer.addRecord(nextNumber++);
+            }
+            writer.finishCurrentSegment();
+        }
+        writer.finish();
+        return writer;
+    }
+
+    private void checkWriteAndRecoverAndReplay(
+            int[] numRecordsPerSegment, DataCacheSnapshot... dataCacheSnapshots) throws Exception {
+        ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        for (DataCacheSnapshot dataCacheSnapshot : dataCacheSnapshots) {
+            dataCacheSnapshot.writeTo(bos);
+        }
+
+        byte[] data = bos.toByteArray();
+
+        ByteArrayInputStream recoverInputStream = new ByteArrayInputStream(data);
+        for (DataCacheSnapshot dataCacheSnapshot : dataCacheSnapshots) {
+            checkRecover(dataCacheSnapshot, recoverInputStream);
+        }
+
+        ByteArrayInputStream replayInputStream = new ByteArrayInputStream(data);
+        for (DataCacheSnapshot dataCacheSnapshot : dataCacheSnapshots) {
+            checkReplay(dataCacheSnapshot, replayInputStream, numRecordsPerSegment);
+        }
+    }
+
+    private void checkRecover(DataCacheSnapshot dataCacheSnapshot, InputStream inputStream)
+            throws IOException {
+        DataCacheSnapshot copied =
+                DataCacheSnapshot.recover(
+                        inputStream,
+                        dataCacheSnapshot.getFileSystem(),
+                        () -> new Path(basePath, "writer." + UUID.randomUUID().toString()));
+        if (dataCacheSnapshot.getFileSystem().isDistributedFS()) {
+            assertEquals(dataCacheSnapshot.getSegments(), copied.getSegments());
+        } else {
+            assertEquals(readElements(dataCacheSnapshot), readElements(copied));
+        }
+
+        assertEquals(dataCacheSnapshot.getReaderPosition(), copied.getReaderPosition());
+    }
+
+    private void checkReplay(
+            DataCacheSnapshot dataCacheSnapshot,
+            InputStream inputStream,
+            int[] numRecordsPerSegment)
+            throws Exception {
+        List<Integer> elements = new ArrayList<>();
+        DataCacheSnapshot.replay(inputStream, IntSerializer.INSTANCE, fileSystem, elements::add);
+
+        int totalRecords = IntStream.of(numRecordsPerSegment).sum();
+        assertEquals(
+                IntStream.range(0, totalRecords).boxed().collect(Collectors.toList()), elements);
+    }
+
+    private List<Integer> readElements(DataCacheSnapshot dataCacheSnapshot) throws IOException {
+        DataCacheReader<Integer> reader =
+                new DataCacheReader<>(
+                        IntSerializer.INSTANCE,
+                        dataCacheSnapshot.getFileSystem(),
+                        dataCacheSnapshot.getSegments());
+        List<Integer> result = new ArrayList<>();
+        while (reader.hasNext()) {
+            result.add(reader.next());
+        }
+
+        return result;
+    }
+}