Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:43 UTC

[flink-ml] 05/08: [FLINK-24655][iteration] Support the checkpoints for the iteration

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 4b47e5df79b35336c4271f9dce1cbac9781451ef
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 00:39:49 2021 +0800

    [FLINK-24655][iteration] Support the checkpoints for the iteration
---
 .../flink/iteration/checkpoint/Checkpoints.java    |  77 +++-
 .../iteration/checkpoint/CheckpointsBroker.java    |  57 +++
 .../flink/iteration/operator/HeadOperator.java     | 249 ++++++++++-
 .../operator/HeadOperatorCheckpointAligner.java    |  26 ++
 .../iteration/operator/HeadOperatorFactory.java    |   8 +
 ...dOperatorState.java => OperatorStateUtils.java} |  26 +-
 .../flink/iteration/operator/TailOperator.java     |  27 ++
 .../coordinator/HeadOperatorCoordinator.java       |  10 +-
 .../coordinator/SharedProgressAligner.java         |  28 ++
 .../headprocessor/HeadOperatorRecordProcessor.java |   7 +-
 .../operator/headprocessor/HeadOperatorState.java  |  32 +-
 .../RegularHeadOperatorRecordProcessor.java        | 110 ++++-
 .../TerminatingHeadOperatorRecordProcessor.java    |   7 +-
 .../flink/iteration/IterationConstructionTest.java |  38 ++
 .../flink/iteration/operator/HeadOperatorTest.java | 496 ++++++++++++++++++++-
 15 files changed, 1154 insertions(+), 44 deletions(-)
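
For orientation, the flow introduced by this commit is roughly: the tail operator injects a
BARRIER iteration record carrying the checkpoint id into the feedback channel right before its
own snapshot, while the head operator appends every feedback record to the data caches of all
pending checkpoints and commits a checkpoint once its barrier record has travelled around the
loop. The sketch below condenses the relevant pieces of TailOperator and HeadOperator from the
diff; it is illustrative only and omits the surrounding class members and error handling.

    // Tail side: inject a BARRIER record before the tail's own snapshot.
    @Override
    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
        super.prepareSnapshotPreBarrier(checkpointId);
        channel.put(new StreamRecord<>(IterationRecord.newBarrier(checkpointId)));
    }

    // Head side: log feedback records per pending checkpoint, and commit all
    // checkpoints up to the barrier's id once the barrier returns.
    @Override
    public void processFeedback(StreamRecord<IterationRecord<?>> iterationRecord) throws Exception {
        if (iterationRecord.getValue().getType() == IterationRecord.Type.BARRIER) {
            checkpoints.commitCheckpointsUntil(iterationRecord.getValue().getCheckpointId());
            return;
        }
        checkpoints.append(iterationRecord.getValue());
        // ... the record is then handed to the regular record processor.
    }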

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
index 03420f8..edbfeba 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/Checkpoints.java
@@ -19,11 +19,13 @@
 package org.apache.flink.iteration.checkpoint;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.core.fs.FileSystem;
 import org.apache.flink.core.fs.Path;
 import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
 import org.apache.flink.iteration.datacache.nonkeyed.DataCacheWriter;
 import org.apache.flink.runtime.state.OperatorStateCheckpointOutputStream;
+import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.ResourceGuard;
 import org.apache.flink.util.function.SupplierWithException;
 
@@ -33,6 +35,9 @@ import org.slf4j.LoggerFactory;
 import java.io.IOException;
 import java.util.SortedMap;
 import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+
+import static org.apache.flink.util.Preconditions.checkState;
 
 /** Maintains the pending checkpoints. */
 public class Checkpoints<T> implements AutoCloseable {
@@ -43,7 +48,10 @@ public class Checkpoints<T> implements AutoCloseable {
     private final FileSystem fileSystem;
     private final SupplierWithException<Path, IOException> pathSupplier;
 
-    private final TreeMap<Long, PendingCheckpoint> uncompletedCheckpoints = new TreeMap<>();
+    private final ConcurrentHashMap<Long, Tuple2<PendingCheckpoint, Boolean>>
+            uncompletedCheckpoints = new ConcurrentHashMap<>();
+
+    private final TreeMap<Long, PendingCheckpoint> sortedUncompletedCheckpoints = new TreeMap<>();
 
     public Checkpoints(
             TypeSerializer<T> typeSerializer,
@@ -51,27 +59,72 @@ public class Checkpoints<T> implements AutoCloseable {
             SupplierWithException<Path, IOException> pathSupplier) {
         this.typeSerializer = typeSerializer;
         this.fileSystem = fileSystem;
+        checkState(!fileSystem.isDistributedFS(), "Currently only local fs is supported");
         this.pathSupplier = pathSupplier;
     }
 
+    public TypeSerializer<T> getTypeSerializer() {
+        return typeSerializer;
+    }
+
+    public FileSystem getFileSystem() {
+        return fileSystem;
+    }
+
+    public SupplierWithException<Path, IOException> getPathSupplier() {
+        return pathSupplier;
+    }
+
     public void startLogging(long checkpointId, OperatorStateCheckpointOutputStream outputStream)
             throws IOException {
-        DataCacheWriter<T> dataCacheWriter =
-                new DataCacheWriter<>(typeSerializer, fileSystem, pathSupplier);
-        ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
-        uncompletedCheckpoints.put(
-                checkpointId, new PendingCheckpoint(dataCacheWriter, outputStream, snapshotLease));
+        Tuple2<PendingCheckpoint, Boolean> possibleCheckpoint =
+                uncompletedCheckpoints.computeIfAbsent(
+                        checkpointId,
+                        ignored -> {
+                            try {
+                                DataCacheWriter<T> dataCacheWriter =
+                                        new DataCacheWriter<>(
+                                                typeSerializer, fileSystem, pathSupplier);
+                                ResourceGuard.Lease snapshotLease = outputStream.acquireLease();
+                                return new Tuple2<>(
+                                        new PendingCheckpoint(
+                                                dataCacheWriter, outputStream, snapshotLease),
+                                        false);
+                            } catch (IOException e) {
+                                throw new FlinkRuntimeException(e);
+                            }
+                        });
+
+        // If the checkpoint has already been aborted, return directly.
+        if (possibleCheckpoint.f1) {
+            return;
+        }
+
+        sortedUncompletedCheckpoints.put(checkpointId, possibleCheckpoint.f0);
+    }
+
+    public void abort(long checkpointId) {
+        uncompletedCheckpoints.compute(
+                checkpointId,
+                (k, v) -> {
+                    if (v == null) {
+                        return new Tuple2<>(null, true);
+                    } else {
+                        v.f0.snapshotLease.close();
+                        return new Tuple2<>(v.f0, true);
+                    }
+                });
     }
 
     public void append(T element) throws IOException {
-        for (PendingCheckpoint pendingCheckpoint : uncompletedCheckpoints.values()) {
+        for (PendingCheckpoint pendingCheckpoint : sortedUncompletedCheckpoints.values()) {
             pendingCheckpoint.dataCacheWriter.addRecord(element);
         }
     }
 
     public void commitCheckpointsUntil(long checkpointId) {
         SortedMap<Long, PendingCheckpoint> completedCheckpoints =
-                uncompletedCheckpoints.headMap(checkpointId, true);
+                sortedUncompletedCheckpoints.headMap(checkpointId, true);
         completedCheckpoints
                 .values()
                 .forEach(
@@ -86,9 +139,13 @@ public class Checkpoints<T> implements AutoCloseable {
                                                         .getFinishSegments());
                                 pendingCheckpoint.checkpointOutputStream.startNewPartition();
                                 snapshot.writeTo(pendingCheckpoint.checkpointOutputStream);
+
+                                // Directly clean up all the files since we are using the local fs.
+                                // TODO: support the remote fs.
                                 pendingCheckpoint.dataCacheWriter.cleanup();
                             } catch (Exception e) {
                                 LOG.error("Failed to commit checkpoint until " + checkpointId, e);
+                                throw new FlinkRuntimeException(e);
                             } finally {
                                 pendingCheckpoint.snapshotLease.close();
                             }
@@ -99,7 +156,7 @@ public class Checkpoints<T> implements AutoCloseable {
 
     @Override
     public void close() {
-        uncompletedCheckpoints.forEach(
+        sortedUncompletedCheckpoints.forEach(
                 (checkpointId, pendingCheckpoint) -> {
                     pendingCheckpoint.snapshotLease.close();
                     try {
@@ -108,10 +165,12 @@ public class Checkpoints<T> implements AutoCloseable {
                         LOG.error("Failed to cleanup " + checkpointId, e);
                     }
                 });
+        sortedUncompletedCheckpoints.clear();
         uncompletedCheckpoints.clear();
     }
 
     private class PendingCheckpoint {
+
         final DataCacheWriter<T> dataCacheWriter;
 
         final OperatorStateCheckpointOutputStream checkpointOutputStream;
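
For reference, the Checkpoints<T> lifecycle introduced above can be read as the following
minimal sketch. The serializer, file system, path supplier, and the raw operator state output
stream are assumed to be supplied by the enclosing operator (as HeadOperator does further down);
the variable names here are illustrative only.

    Checkpoints<IterationRecord<?>> checkpoints =
            new Checkpoints<>(typeSerializer, localFileSystem, pathSupplier);

    // snapshotState: start logging feedback records for this checkpoint into a
    // fresh data cache, holding a lease on the raw state output stream.
    checkpoints.startLogging(checkpointId, rawOperatorStateOutput);

    // Every feedback record seen while checkpoints are pending is cached.
    checkpoints.append(record);

    // When the BARRIER record returns through the feedback edge, the cached
    // records are written to the raw state and the cache files are cleaned up.
    checkpoints.commitCheckpointsUntil(checkpointId);

    // If the checkpoint is aborted instead (driven from the tail operator),
    // the lease is released and later startLogging calls for this id are no-ops.
    checkpoints.abort(checkpointId);
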
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
new file mode 100644
index 0000000..e969c10
--- /dev/null
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/checkpoint/CheckpointsBroker.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.iteration.checkpoint;
+
+import org.apache.flink.statefun.flink.core.feedback.SubtaskFeedbackKey;
+
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * Hands off the {@link Checkpoints} from the head operator to the tail operator so that the tail
+ * operator can decrease the reference count of the raw state when checkpoints are aborted. We
+ * cannot rely on the head operator since it would be blocked on closing the raw state when
+ * aborting the checkpoint, which also looks like a bug.
+ */
+public class CheckpointsBroker {
+
+    private static final CheckpointsBroker INSTANCE = new CheckpointsBroker();
+
+    private final ConcurrentHashMap<SubtaskFeedbackKey<?>, Checkpoints<?>> checkpointManagers =
+            new ConcurrentHashMap<>();
+
+    public static CheckpointsBroker get() {
+        return INSTANCE;
+    }
+
+    public <V> void setCheckpoints(SubtaskFeedbackKey<V> key, Checkpoints<V> checkpoints) {
+        checkpointManagers.put(key, checkpoints);
+    }
+
+    @SuppressWarnings({"unchecked"})
+    public <V> Checkpoints<V> getCheckpoints(SubtaskFeedbackKey<V> key) {
+        Objects.requireNonNull(key);
+        return (Checkpoints<V>) Objects.requireNonNull(checkpointManagers.get(key));
+    }
+
+    @SuppressWarnings("resource")
+    void removeChannel(SubtaskFeedbackKey<?> key) {
+        checkpointManagers.remove(key);
+    }
+}
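
The broker is essentially a process-wide map keyed by the per-subtask feedback key, so the head
and tail operators of the same iteration subtask resolve the same Checkpoints instance. Below is
a minimal sketch of how the two sides meet, mirroring HeadOperator#initializeState and
TailOperator#notifyCheckpointAborted further down (subtaskIndex and attemptNumber stand in for
the runtime context values):

    // Head operator side: register the Checkpoints instance under the
    // per-subtask feedback key of this iteration.
    SubtaskFeedbackKey<IterationRecord<?>> key =
            OperatorUtils.<IterationRecord<?>>createFeedbackKey(iterationId, feedbackIndex)
                    .withSubTaskIndex(subtaskIndex, attemptNumber);
    CheckpointsBroker.get().setCheckpoints(key, checkpoints);

    // Tail operator side: resolve the same instance on notifyCheckpointAborted
    // and abort the pending checkpoint there.
    Checkpoints<?> sameInstance = CheckpointsBroker.get().getCheckpoints(key);
    sameInstance.abort(checkpointId);
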
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
index 6e5a7b4..5306d2b 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperator.java
@@ -21,23 +21,46 @@ package org.apache.flink.iteration.operator;
 import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.core.fs.Path;
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.broadcast.BroadcastOutput;
 import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
+import org.apache.flink.iteration.checkpoint.Checkpoints;
+import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
+import org.apache.flink.iteration.datacache.nonkeyed.DataCacheSnapshot;
 import org.apache.flink.iteration.operator.event.CoordinatorCheckpointEvent;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.HeadOperatorRecordProcessor;
+import org.apache.flink.iteration.operator.headprocessor.HeadOperatorState;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.operator.headprocessor.TerminatingHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.iteration.utils.ReflectionUtils;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
+import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferConsumerWithPartialRecordLength;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartitionView;
+import org.apache.flink.runtime.io.network.partition.PrioritizedDeque;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
 import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
@@ -57,6 +80,9 @@ import org.apache.flink.util.FlinkRuntimeException;
 import org.apache.flink.util.OutputTag;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.Executor;
 
@@ -112,6 +138,16 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     private HeadOperatorCheckpointAligner checkpointAligner;
 
+    // ------------- states -------------------
+
+    private ListState<Integer> parallelismState;
+
+    private ListState<Integer> statusState;
+
+    private ListState<HeadOperatorState> processorState;
+
+    private Checkpoints<IterationRecord<?>> checkpoints;
+
     public HeadOperator(
             IterationID iterationId,
             int feedbackIndex,
@@ -146,12 +182,87 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     public void initializeState(StateInitializationContext context) throws Exception {
         super.initializeState(context);
 
+        parallelismState =
+                context.getOperatorStateStore()
+                        .getUnionListState(
+                                new ListStateDescriptor<>("parallelism", IntSerializer.INSTANCE));
+        OperatorStateUtils.getUniqueElement(parallelismState, "parallelism")
+                .ifPresent(
+                        oldParallelism ->
+                                checkState(
+                                        oldParallelism
+                                                == getRuntimeContext()
+                                                        .getNumberOfParallelSubtasks(),
+                                        "The head operator is recovered with parallelism changed from "
+                                                + oldParallelism
+                                                + " to "
+                                                + getRuntimeContext()
+                                                        .getNumberOfParallelSubtasks()));
+
+        // Initialize the status and the record processor.
         processorContext = new ContextImpl();
-        status = HeadOperatorStatus.RUNNING;
-        recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+        statusState =
+                context.getOperatorStateStore()
+                        .getListState(new ListStateDescriptor<>("status", Integer.class));
+        status =
+                HeadOperatorStatus.values()[
+                        OperatorStateUtils.getUniqueElement(statusState, "status").orElse(0)];
+        if (status == HeadOperatorStatus.RUNNING) {
+            recordProcessor = new RegularHeadOperatorRecordProcessor(processorContext);
+        } else {
+            recordProcessor = new TerminatingHeadOperatorRecordProcessor();
+        }
+
+        // Recover the processor state if it exists.
+        processorState =
+                context.getOperatorStateStore()
+                        .getListState(
+                                new ListStateDescriptor<>(
+                                        "processorState", HeadOperatorState.class));
+        OperatorStateUtils.getUniqueElement(processorState, "processorState")
+                .ifPresent(
+                        headOperatorState ->
+                                recordProcessor.initializeState(
+                                        headOperatorState, context.getRawOperatorStateInputs()));
 
         checkpointAligner = new HeadOperatorCheckpointAligner();
 
+        // Initialize the checkpoints
+        Path dataCachePath =
+                OperatorUtils.getDataCachePath(
+                        getRuntimeContext().getTaskManagerRuntimeInfo().getConfiguration(),
+                        getContainingTask()
+                                .getEnvironment()
+                                .getIOManager()
+                                .getSpillingDirectoriesPaths());
+        this.checkpoints =
+                new Checkpoints<>(
+                        config.getTypeSerializerOut(getClass().getClassLoader()),
+                        dataCachePath.getFileSystem(),
+                        OperatorUtils.createDataCacheFileGenerator(
+                                dataCachePath, "header-cp", getOperatorConfig().getOperatorID()));
+        CheckpointsBroker.get()
+                .setCheckpoints(
+                        OperatorUtils.<IterationRecord<?>>createFeedbackKey(
+                                        iterationId, feedbackIndex)
+                                .withSubTaskIndex(
+                                        getRuntimeContext().getIndexOfThisSubtask(),
+                                        getRuntimeContext().getAttemptNumber()),
+                        checkpoints);
+
+        try {
+            for (StatePartitionStreamProvider rawStateInput : context.getRawOperatorStateInputs()) {
+                DataCacheSnapshot.replay(
+                        rawStateInput.getStream(),
+                        checkpoints.getTypeSerializer(),
+                        checkpoints.getFileSystem(),
+                        (record) ->
+                                recordProcessor.processFeedbackElement(new StreamRecord<>(record)));
+            }
+        } catch (Exception e) {
+            throw new FlinkRuntimeException("Failed to replay the records", e);
+        }
+
         // Here we register a mail
         registerFeedbackConsumer(
                 (Runnable runnable) -> {
@@ -171,18 +282,54 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
     @Override
     public void snapshotState(StateSnapshotContext context) throws Exception {
         super.snapshotState(context);
+
+        // Always clear the union list state before setting the value.
+        parallelismState.clear();
+        if (getRuntimeContext().getIndexOfThisSubtask() == 0) {
+            parallelismState.update(
+                    Collections.singletonList(getRuntimeContext().getNumberOfParallelSubtasks()));
+        }
+        statusState.update(Collections.singletonList(status.ordinal()));
+
+        HeadOperatorState currentProcessorState = recordProcessor.snapshotState();
+        if (currentProcessorState != null) {
+            processorState.update(Collections.singletonList(currentProcessorState));
+        } else {
+            processorState.clear();
+        }
+
+        if (status == HeadOperatorStatus.RUNNING) {
+            checkpoints.startLogging(
+                    context.getCheckpointId(), context.getRawOperatorStateOutput());
+        }
+
         checkpointAligner
                 .onStateSnapshot(context.getCheckpointId())
                 .forEach(this::processGloballyAlignedEvent);
     }
 
     @Override
+    public void notifyCheckpointAborted(long checkpointId) throws Exception {
+        super.notifyCheckpointAborted(checkpointId);
+
+        checkpointAligner
+                .onCheckpointAborted(checkpointId)
+                .forEach(this::processGloballyAlignedEvent);
+    }
+
+    @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) throws Exception {
         recordProcessor.processElement(element);
     }
 
     @Override
     public void processFeedback(StreamRecord<IterationRecord<?>> iterationRecord) throws Exception {
+        if (iterationRecord.getValue().getType() == IterationRecord.Type.BARRIER) {
+            checkpoints.commitCheckpointsUntil(iterationRecord.getValue().getCheckpointId());
+            return;
+        }
+
+        checkpoints.append(iterationRecord.getValue());
         boolean terminated = recordProcessor.processFeedbackElement(iterationRecord);
         if (terminated) {
             checkState(status == HeadOperatorStatus.TERMINATING);
@@ -213,13 +360,65 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
 
     @Override
     public void endInput() throws Exception {
-        recordProcessor.processElement(
-                new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+        if (status == HeadOperatorStatus.RUNNING) {
+            recordProcessor.processElement(
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+        }
+
+        // Since we choose to block here, we cannot continue processing the barriers received
+        // from the task inputs, which would prevent the preceding tasks from finishing since
+        // they need to complete their final checkpoint. As a temporary workaround for this issue,
+        // we check the input channels and trigger all the checkpoints found until we see
+        // the EndOfPartitionEvent.
+        checkState(getContainingTask().getEnvironment().getAllInputGates().length == 1);
+        checkState(
+                getContainingTask()
+                                .getEnvironment()
+                                .getAllInputGates()[0]
+                                .getNumberOfInputChannels()
+                        == 1);
+        InputChannel inputChannel =
+                getContainingTask().getEnvironment().getAllInputGates()[0].getChannel(0);
+
+        boolean endOfPartitionReceived = false;
+        long lastTriggerCheckpointId = 0;
+        while (!endOfPartitionReceived && status != HeadOperatorStatus.TERMINATED) {
+            mailboxExecutor.tryYield();
+            Thread.sleep(200);
+
+            List<AbstractEvent> events = parseInputChannelEvents(inputChannel);
+
+            for (AbstractEvent event : events) {
+                if (event instanceof CheckpointBarrier) {
+                    CheckpointBarrier barrier = (CheckpointBarrier) event;
+                    if (barrier.getId() > lastTriggerCheckpointId) {
+                        getContainingTask()
+                                .triggerCheckpointAsync(
+                                        new CheckpointMetaData(
+                                                barrier.getId(), barrier.getTimestamp()),
+                                        barrier.getCheckpointOptions());
+                        lastTriggerCheckpointId = barrier.getId();
+                    }
+
+                } else if (event instanceof EndOfPartitionEvent) {
+                    endOfPartitionReceived = true;
+                }
+            }
+        }
+
+        // From here on we can step into the normal termination loop.
         while (status != HeadOperatorStatus.TERMINATED) {
             mailboxExecutor.yield();
         }
     }
 
+    @Override
+    public void close() throws Exception {
+        if (checkpoints != null) {
+            checkpoints.close();
+        }
+    }
+
     private void registerFeedbackConsumer(Executor mailboxExecutor) {
         int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
         int attemptNum = getRuntimeContext().getAttemptNumber();
@@ -232,6 +431,48 @@ public class HeadOperator extends AbstractStreamOperator<IterationRecord<?>>
         OperatorUtils.registerFeedbackConsumer(channel, this, mailboxExecutor);
     }
 
+    private List<AbstractEvent> parseInputChannelEvents(InputChannel inputChannel)
+            throws Exception {
+        List<AbstractEvent> events = new ArrayList<>();
+        if (inputChannel instanceof RemoteInputChannel) {
+            Class<?> seqBufferClass =
+                    Class.forName(
+                            "org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel$SequenceBuffer");
+            PrioritizedDeque<?> queue =
+                    ReflectionUtils.getFieldValue(
+                            inputChannel, RemoteInputChannel.class, "receivedBuffers");
+            for (Object sequenceBuffer : queue) {
+                Buffer buffer =
+                        ReflectionUtils.getFieldValue(sequenceBuffer, seqBufferClass, "buffer");
+                if (!buffer.isBuffer()) {
+                    events.add(EventSerializer.fromBuffer(buffer, getClass().getClassLoader()));
+                }
+            }
+        } else if (inputChannel instanceof LocalInputChannel) {
+            PipelinedSubpartitionView subpartitionView =
+                    ReflectionUtils.getFieldValue(
+                            inputChannel, LocalInputChannel.class, "subpartitionView");
+            PipelinedSubpartition pipelinedSubpartition =
+                    ReflectionUtils.getFieldValue(
+                            subpartitionView, PipelinedSubpartitionView.class, "parent");
+            PrioritizedDeque<BufferConsumerWithPartialRecordLength> queue =
+                    ReflectionUtils.getFieldValue(
+                            pipelinedSubpartition, PipelinedSubpartition.class, "buffers");
+            for (BufferConsumerWithPartialRecordLength bufferConsumer : queue) {
+                if (!bufferConsumer.getBufferConsumer().isBuffer()) {
+                    events.add(
+                            EventSerializer.fromBuffer(
+                                    bufferConsumer.getBufferConsumer().copy().build(),
+                                    getClass().getClassLoader()));
+                }
+            }
+        } else {
+            LOG.warn("Unknown input channel type: " + inputChannel);
+        }
+
+        return events;
+    }
+
     @VisibleForTesting
     public OperatorEventGateway getOperatorEventGateway() {
         return operatorEventGateway;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
index b4f4215..9b83a7d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorCheckpointAligner.java
@@ -24,6 +24,7 @@ import org.apache.flink.util.function.RunnableWithException;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Optional;
 import java.util.TreeMap;
 
@@ -40,6 +41,8 @@ class HeadOperatorCheckpointAligner {
 
     private long latestCheckpointFromCoordinator;
 
+    private long latestAbortedCheckpoint;
+
     HeadOperatorCheckpointAligner() {
         this.checkpointAlignmments = new TreeMap<>();
     }
@@ -58,6 +61,13 @@ class HeadOperatorCheckpointAligner {
     void coordinatorNotify(CoordinatorCheckpointEvent checkpointEvent) {
         checkState(checkpointEvent.getCheckpointId() > latestCheckpointFromCoordinator);
         latestCheckpointFromCoordinator = checkpointEvent.getCheckpointId();
+
+        // Do nothing if a later checkpoint has already been aborted. In this case there will be
+        // no notification from the task side.
+        if (latestCheckpointFromCoordinator <= latestAbortedCheckpoint) {
+            return;
+        }
+
         CheckpointAlignment checkpointAlignment =
                 checkpointAlignmments.computeIfAbsent(
                         checkpointEvent.getCheckpointId(),
@@ -86,6 +96,22 @@ class HeadOperatorCheckpointAligner {
         return checkpointAlignment.pendingGlobalEvents;
     }
 
+    List<GloballyAlignedEvent> onCheckpointAborted(long checkpointId) {
+        // Here we need to abort all the checkpoints <= notified checkpoint id.
+        checkState(checkpointId > latestAbortedCheckpoint);
+        latestAbortedCheckpoint = checkpointId;
+
+        Map<Long, CheckpointAlignment> abortedAlignments =
+                checkpointAlignmments.headMap(latestAbortedCheckpoint, true);
+        List<GloballyAlignedEvent> events = new ArrayList<>();
+        abortedAlignments
+                .values()
+                .forEach(alignment -> events.addAll(alignment.pendingGlobalEvents));
+        abortedAlignments.clear();
+
+        return events;
+    }
+
     private static class CheckpointAlignment {
 
         final List<GloballyAlignedEvent> pendingGlobalEvents;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
index 2bd1e38..5040506 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/HeadOperatorFactory.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.OperatorCoordinator;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperatorFactory;
+import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.CoordinatedOperatorFactory;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -132,4 +133,11 @@ public class HeadOperatorFactory extends AbstractStreamOperatorFactory<Iteration
         // We need it to be yielding operator factory to disable chaining,
         // but we cannot use the given mailbox here since it has bugs.
     }
+
+    @Override
+    public ChainingStrategy getChainingStrategy() {
+        // We cannot allow the head operator to be chained with the previous operator because
+        // of the special treatment in endInput.
+        return ChainingStrategy.HEAD;
+    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
similarity index 51%
copy from flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
copy to flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
index 861f481..fc4cd07 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/OperatorStateUtils.java
@@ -16,7 +16,27 @@
  * limitations under the License.
  */
 
-package org.apache.flink.iteration.operator.headprocessor;
+package org.apache.flink.iteration.operator;
 
-/** The state entry for the head operator. */
-public class HeadOperatorState {}
+import org.apache.flink.api.common.state.ListState;
+
+import java.util.Iterator;
+import java.util.Optional;
+
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Utility to deal with the states inside the operator. */
+public class OperatorStateUtils {
+
+    public static <T> Optional<T> getUniqueElement(ListState<T> listState, String stateName)
+            throws Exception {
+        Iterator<T> iterator = listState.get().iterator();
+        if (!iterator.hasNext()) {
+            return Optional.empty();
+        }
+
+        T result = iterator.next();
+        checkState(!iterator.hasNext(), "The state " + stateName + " has more than one element");
+        return Optional.of(result);
+    }
+}
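
A short usage sketch of getUniqueElement, mirroring how HeadOperator#initializeState below reads
the single "status" entry (the fallback to ordinal 0, i.e. RUNNING, is taken from that code; the
context variable is the StateInitializationContext):

    ListState<Integer> statusState =
            context.getOperatorStateStore()
                    .getListState(new ListStateDescriptor<>("status", Integer.class));
    // Either the unique element written by the last snapshot, or the default.
    int statusOrdinal =
            OperatorStateUtils.getUniqueElement(statusState, "status").orElse(0);
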
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
index a58155e..2e26142 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/TailOperator.java
@@ -20,6 +20,8 @@ package org.apache.flink.iteration.operator;
 
 import org.apache.flink.iteration.IterationID;
 import org.apache.flink.iteration.IterationRecord;
+import org.apache.flink.iteration.checkpoint.Checkpoints;
+import org.apache.flink.iteration.checkpoint.CheckpointsBroker;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannel;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackChannelBroker;
 import org.apache.flink.statefun.flink.core.feedback.FeedbackKey;
@@ -90,6 +92,31 @@ public class TailOperator extends AbstractStreamOperator<Void>
         recordConsumer.accept(streamRecord);
     }
 
+    @Override
+    public void prepareSnapshotPreBarrier(long checkpointId) throws Exception {
+        super.prepareSnapshotPreBarrier(checkpointId);
+        channel.put(new StreamRecord<>(IterationRecord.newBarrier(checkpointId)));
+    }
+
+    @Override
+    public void notifyCheckpointAborted(long checkpointId) throws Exception {
+        super.notifyCheckpointAborted(checkpointId);
+
+        // TODO: Unfortunately, we have to rely on the tail operator to help
+        // abort the checkpoint since the task thread of the head operator
+        // might get blocked because it is not able to close the raw state files.
+        // We will try to fix this on the Flink side in the future.
+        SubtaskFeedbackKey<?> key =
+                OperatorUtils.createFeedbackKey(iterationId, feedbackIndex)
+                        .withSubTaskIndex(
+                                getRuntimeContext().getIndexOfThisSubtask(),
+                                getRuntimeContext().getAttemptNumber());
+        Checkpoints<?> checkpoints = CheckpointsBroker.get().getCheckpoints(key);
+        if (checkpoints != null) {
+            checkpoints.abort(checkpointId);
+        }
+    }
+
     private void processIfObjectReuseEnabled(StreamRecord<IterationRecord<?>> record) {
         // Since the record would be reused, we have to clone a new one
         IterationRecord<?> cloned = record.getValue().clone();
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
index 2c01c33..f4d6e2d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/HeadOperatorCoordinator.java
@@ -63,10 +63,16 @@ public class HeadOperatorCoordinator implements OperatorCoordinator, SharedProgr
     }
 
     @Override
-    public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {}
+    public void resetToCheckpoint(long checkpointId, @Nullable byte[] bytes) {
+        for (int i = 0; i < context.currentParallelism(); ++i) {
+            sharedProgressAligner.removeProgressInfo(context.getOperatorId());
+        }
+    }
 
     @Override
-    public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {}
+    public void subtaskFailed(int subtaskIndex, @Nullable Throwable throwable) {
+        sharedProgressAligner.removeProgressInfo(context.getOperatorId(), subtaskIndex);
+    }
 
     @Override
     public void handleEventFromOperator(int subtaskIndex, OperatorEvent operatorEvent) {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
index 9c2fd43..4c2204d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/coordinator/SharedProgressAligner.java
@@ -182,6 +182,24 @@ public class SharedProgressAligner {
                 checkpointId);
     }
 
+    public void removeProgressInfo(OperatorID operatorId) {
+        runInEventLoop(
+                () -> statusByEpoch.values().forEach(status -> status.remove(operatorId)),
+                "remove the progress information for {}",
+                operatorId);
+    }
+
+    public void removeProgressInfo(OperatorID operatorId, int subtaskIndex) {
+        runInEventLoop(
+                () ->
+                        statusByEpoch
+                                .values()
+                                .forEach(status -> status.remove(operatorId, subtaskIndex)),
+                "remove the progress information for {}-{}",
+                operatorId,
+                subtaskIndex);
+    }
+
     private void runInEventLoop(
             ThrowingRunnable<Throwable> action,
             String actionName,
@@ -234,6 +252,16 @@ public class SharedProgressAligner {
             return reportedSubtasks.size() == totalHeadParallelism;
         }
 
+        public void remove(OperatorID operatorID) {
+            reportedSubtasks
+                    .entrySet()
+                    .removeIf(entry -> entry.getKey().getOperatorId().equals(operatorID));
+        }
+
+        public void remove(OperatorID operatorID, int subtaskIndex) {
+            reportedSubtasks.remove(new OperatorInstanceID(subtaskIndex, operatorID));
+        }
+
         public boolean isTerminated() {
             checkState(
                     reportedSubtasks.size() == totalHeadParallelism,
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
index 78132c3..26d397d 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorRecordProcessor.java
@@ -22,14 +22,18 @@ import org.apache.flink.api.common.TaskInfo;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.HeadOperator;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.OutputTag;
 
+import javax.annotation.Nullable;
+
 /** The component to actually deal with the event received in the {@link HeadOperator}. */
 public interface HeadOperatorRecordProcessor {
 
-    void initializeState(HeadOperatorState headOperatorState) throws Exception;
+    void initializeState(
+            HeadOperatorState headOperatorState, Iterable<StatePartitionStreamProvider> rawStates);
 
     void processElement(StreamRecord<IterationRecord<?>> record);
 
@@ -37,6 +41,7 @@ public interface HeadOperatorRecordProcessor {
 
     boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent);
 
+    @Nullable
     HeadOperatorState snapshotState();
 
     /** The context for {@link HeadOperatorRecordProcessor}. */
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
index 861f481..996d929 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/HeadOperatorState.java
@@ -18,5 +18,35 @@
 
 package org.apache.flink.iteration.operator.headprocessor;
 
+import java.util.Map;
+
 /** The state entry for the head operator. */
-public class HeadOperatorState {}
+public class HeadOperatorState {
+
+    private Map<Integer, Long> numFeedbackRecordsEachRound;
+
+    private int latestRoundAligned;
+
+    private int latestRoundGloballyAligned;
+
+    public HeadOperatorState(
+            Map<Integer, Long> numFeedbackRecordsEachRound,
+            int latestRoundAligned,
+            int latestRoundGloballyAligned) {
+        this.numFeedbackRecordsEachRound = numFeedbackRecordsEachRound;
+        this.latestRoundAligned = latestRoundAligned;
+        this.latestRoundGloballyAligned = latestRoundGloballyAligned;
+    }
+
+    public Map<Integer, Long> getNumFeedbackRecordsEachRound() {
+        return numFeedbackRecordsEachRound;
+    }
+
+    public int getLatestRoundAligned() {
+        return latestRoundAligned;
+    }
+
+    public int getLatestRoundGloballyAligned() {
+        return latestRoundGloballyAligned;
+    }
+}
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
index f1a6b0f..107a233 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/RegularHeadOperatorRecordProcessor.java
@@ -22,6 +22,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 import org.slf4j.Logger;
@@ -30,6 +31,9 @@ import org.slf4j.LoggerFactory;
 import java.util.HashMap;
 import java.util.Map;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
 /**
  * Processes the event before we received the terminated global aligned event from the coordinator.
  */
@@ -40,26 +44,46 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
 
     private final Context headOperatorContext;
 
-    private final StreamRecord<IterationRecord<?>> reusable;
-
     private final Map<Integer, Long> numFeedbackRecordsPerEpoch;
 
     private final String senderId;
 
+    private int latestRoundAligned;
+
+    private int latestRoundGloballyAligned;
+
     public RegularHeadOperatorRecordProcessor(Context headOperatorContext) {
         this.headOperatorContext = headOperatorContext;
 
-        this.reusable = new StreamRecord<>(null);
         this.numFeedbackRecordsPerEpoch = new HashMap<>();
 
         this.senderId =
                 OperatorUtils.getUniqueSenderId(
                         headOperatorContext.getStreamConfig().getOperatorID(),
                         headOperatorContext.getTaskInfo().getIndexOfThisSubtask());
+
+        this.latestRoundAligned = -1;
+        this.latestRoundGloballyAligned = -1;
     }
 
     @Override
-    public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+    public void initializeState(
+            HeadOperatorState headOperatorState, Iterable<StatePartitionStreamProvider> rawStates) {
+        checkArgument(headOperatorState != null, "The initialized state should not be null");
+
+        numFeedbackRecordsPerEpoch.putAll(headOperatorState.getNumFeedbackRecordsEachRound());
+        latestRoundAligned = headOperatorState.getLatestRoundAligned();
+        latestRoundGloballyAligned = headOperatorState.getLatestRoundGloballyAligned();
+
+        // If the only round not fully aligned is round 0, then wait until endOfInput in
+        // case the input has changed.
+        if (!(latestRoundAligned == 0 && latestRoundGloballyAligned == -1)) {
+            for (int i = latestRoundGloballyAligned + 1; i <= latestRoundAligned; ++i) {
+                headOperatorContext.updateEpochToCoordinator(
+                        i, numFeedbackRecordsPerEpoch.getOrDefault(i, 0L));
+            }
+        }
+    }
 
     @Override
     public void processElement(StreamRecord<IterationRecord<?>> element) {
@@ -81,22 +105,34 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
     @Override
     public boolean onGloballyAligned(GloballyAlignedEvent globallyAlignedEvent) {
         LOG.info("Received global event {}", globallyAlignedEvent);
-
-        reusable.replace(
-                IterationRecord.newEpochWatermark(
-                        globallyAlignedEvent.isTerminated()
-                                ? Integer.MAX_VALUE
-                                : globallyAlignedEvent.getEpoch(),
-                        senderId),
-                0);
-        headOperatorContext.broadcastOutput(reusable);
-
+        checkState(
+                (globallyAlignedEvent.getEpoch() == 0 && latestRoundGloballyAligned == 0)
+                        || globallyAlignedEvent.getEpoch() > latestRoundGloballyAligned,
+                String.format(
+                        "Receive unexpected global aligned event, latest = %d, this one = %d",
+                        latestRoundGloballyAligned, globallyAlignedEvent.getEpoch()));
+
+        StreamRecord<IterationRecord<?>> record =
+                new StreamRecord<>(
+                        IterationRecord.newEpochWatermark(
+                                globallyAlignedEvent.isTerminated()
+                                        ? Integer.MAX_VALUE
+                                        : globallyAlignedEvent.getEpoch(),
+                                senderId),
+                        0);
+        headOperatorContext.broadcastOutput(record);
+
+        latestRoundGloballyAligned =
+                Math.max(globallyAlignedEvent.getEpoch(), latestRoundGloballyAligned);
         return globallyAlignedEvent.isTerminated();
     }
 
     @Override
     public HeadOperatorState snapshotState() {
-        return new HeadOperatorState();
+        return new HeadOperatorState(
+                new HashMap<>(numFeedbackRecordsPerEpoch),
+                latestRoundAligned,
+                latestRoundGloballyAligned);
     }
 
     @VisibleForTesting
@@ -104,18 +140,50 @@ public class RegularHeadOperatorRecordProcessor implements HeadOperatorRecordPro
         return numFeedbackRecordsPerEpoch;
     }
 
+    @VisibleForTesting
+    public int getLatestRoundAligned() {
+        return latestRoundAligned;
+    }
+
+    @VisibleForTesting
+    public int getLatestRoundGloballyAligned() {
+        return latestRoundGloballyAligned;
+    }
+
     private void processRecord(StreamRecord<IterationRecord<?>> iterationRecord) {
         switch (iterationRecord.getValue().getType()) {
             case RECORD:
-                reusable.replace(iterationRecord.getValue(), iterationRecord.getTimestamp());
-                headOperatorContext.output(reusable);
+                headOperatorContext.output(iterationRecord);
                 break;
             case EPOCH_WATERMARK:
                 LOG.info("Head Received epoch watermark {}", iterationRecord.getValue().getEpoch());
-                headOperatorContext.updateEpochToCoordinator(
-                        iterationRecord.getValue().getEpoch(),
-                        numFeedbackRecordsPerEpoch.getOrDefault(
-                                iterationRecord.getValue().getEpoch(), 0L));
+
+                boolean needNotifyCoordinator = false;
+                if (iterationRecord.getValue().getEpoch() == 0) {
+                    if (latestRoundAligned <= 0) {
+                        needNotifyCoordinator = true;
+                    }
+                } else {
+                    checkState(
+                            iterationRecord.getValue().getEpoch() > latestRoundAligned,
+                            String.format(
+                                    "Unexpected epoch watermark: latest = %d, this one = %d",
+                                    latestRoundAligned, iterationRecord.getValue().getEpoch()));
+                    headOperatorContext.updateEpochToCoordinator(
+                            iterationRecord.getValue().getEpoch(),
+                            numFeedbackRecordsPerEpoch.getOrDefault(
+                                    iterationRecord.getValue().getEpoch(), 0L));
+                }
+
+                if (needNotifyCoordinator) {
+                    headOperatorContext.updateEpochToCoordinator(
+                            iterationRecord.getValue().getEpoch(),
+                            numFeedbackRecordsPerEpoch.getOrDefault(
+                                    iterationRecord.getValue().getEpoch(), 0L));
+                }
+
+                latestRoundAligned =
+                        Math.max(iterationRecord.getValue().getEpoch(), latestRoundAligned);
                 break;
         }
     }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
index c5377e5..f18bfa3 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/headprocessor/TerminatingHeadOperatorRecordProcessor.java
@@ -20,6 +20,7 @@ package org.apache.flink.iteration.operator.headprocessor;
 
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
+import org.apache.flink.runtime.state.StatePartitionStreamProvider;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.FlinkRuntimeException;
 
@@ -30,7 +31,9 @@ import org.apache.flink.util.FlinkRuntimeException;
 public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecordProcessor {
 
     @Override
-    public void initializeState(HeadOperatorState headOperatorState) throws Exception {}
+    public void initializeState(
+            HeadOperatorState headOperatorState,
+            Iterable<StatePartitionStreamProvider> rawStates) {}
 
     @Override
     public void processElement(StreamRecord<IterationRecord<?>> record) {
@@ -55,6 +58,6 @@ public class TerminatingHeadOperatorRecordProcessor implements HeadOperatorRecor
 
     @Override
     public HeadOperatorState snapshotState() {
-        return new HeadOperatorState();
+        return null;
     }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index 9fa0d28..3cd4bf2 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -79,6 +79,44 @@ public class IterationConstructionTest extends TestLogger {
     }
 
     @Test
+    public void testNotChainingHeadOperator() {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(4);
+        DataStream<Integer> variableSource =
+                env.addSource(new DraftExecutionEnvironment.EmptySource<Integer>() {})
+                        .name("Variable")
+                        .map(x -> x)
+                        .name("map")
+                        .setParallelism(2);
+        DataStreamList result =
+                Iterations.iterateUnboundedStreams(
+                        DataStreamList.of(variableSource),
+                        DataStreamList.of(),
+                        ((variableStreams, dataStreams) ->
+                                new IterationBodyResult(variableStreams, dataStreams)));
+
+        JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+        List<String> expectedVertexNames =
+                Arrays.asList(
+                        /* 0 */ "Source: Variable",
+                        /* 1 */ "map -> input-map",
+                        /* 2 */ "head-map",
+                        /* 3 */ "tail-head-map");
+        List<Integer> expectedParallelisms = Arrays.asList(4, 2, 2, 2);
+
+        List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+        assertEquals(
+                expectedVertexNames,
+                vertices.stream().map(JobVertex::getName).collect(Collectors.toList()));
+        assertEquals(
+                expectedParallelisms,
+                vertices.stream().map(JobVertex::getParallelism).collect(Collectors.toList()));
+        assertNotNull(vertices.get(2).getCoLocationGroup());
+        assertSame(vertices.get(2).getCoLocationGroup(), vertices.get(3).getCoLocationGroup());
+    }
+
+    @Test
     public void testUnboundedIteration() {
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
         DataStream<Integer> variableSource1 =
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
index e6d1e26..5603cc9 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/HeadOperatorTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.iteration.operator.event.GloballyAlignedEvent;
 import org.apache.flink.iteration.operator.event.SubtaskAlignedEvent;
 import org.apache.flink.iteration.operator.headprocessor.RegularHeadOperatorRecordProcessor;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
+import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
@@ -43,6 +44,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.util.FlinkException;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TestLogger;
@@ -56,7 +58,10 @@ import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.BlockingQueue;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
@@ -157,7 +162,10 @@ public class HeadOperatorTest extends TestLogger {
 
                                             while (RecordingHeadOperatorFactory.latestHeadOperator
                                                             .getStatus()
-                                                    == HeadOperator.HeadOperatorStatus.RUNNING) {}
+                                                    == HeadOperator.HeadOperatorStatus.RUNNING) {
+                                                Thread.sleep(500);
+                                            }
+
                                             putFeedbackRecords(
                                                     iterationId,
                                                     IterationRecord.newEpochWatermark(
@@ -312,6 +320,452 @@ public class HeadOperatorTest extends TestLogger {
                 });
     }
 
+    @Test
+    public void testSnapshotAndRestoreBeforeRoundZeroFinish() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            harness.processElement(
+                                    new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.emptyList(),
+                            Collections.emptyMap(),
+                            -1,
+                            -1);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreAfterRoundZeroFinishAndRoundOneNotAligned() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            harness.processElement(
+                                    new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+
+                            // Simulates endOfInputs without blocking the main thread.
+                            harness.processElement(
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(0, "fake")));
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.emptyList(),
+                            Collections.emptyMap(),
+                            0,
+                            -1);
+
+                    // Simulates endOfInputs without blocking the main thread.
+                    harness.processElement(
+                            new StreamRecord<>(IterationRecord.newEpochWatermark(0, "fake")));
+                    assertEquals(
+                            Collections.singletonList(new SubtaskAlignedEvent(0, 0, false)),
+                            new ArrayList<>(
+                                    ((RecordingOperatorEventGateway)
+                                                    RecordingHeadOperatorFactory.latestHeadOperator
+                                                            .getOperatorEventGateway())
+                                            .operatorEvents));
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreAfterRoundZeroFinishAndRoundOneAligned() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            harness.processElement(
+                                    new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+
+                            // Simulates endOfInputs without blocking the main thread.
+                            harness.processElement(
+                                    new StreamRecord<>(
+                                            IterationRecord.newEpochWatermark(0, "fake")));
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(0, false));
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(100, 1), null);
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(1, "tail"),
+                                    null);
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.singletonList(new SubtaskAlignedEvent(1, 1, false)),
+                            Collections.singletonMap(1, 1L),
+                            1,
+                            0);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testSnapshotAndRestoreWithFeedbackRecords() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(4, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(4, false));
+                            harness.processAll();
+
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(100, 5), null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(101, 5), null);
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(102, 5), null);
+                            putFeedbackRecords(
+                                    iterationId, IterationRecord.newRecord(103, 6), null);
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(5, "tail"),
+                                    null);
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Arrays.asList(
+                                    new StreamRecord<>(IterationRecord.newRecord(101, 5)),
+                                    new StreamRecord<>(IterationRecord.newRecord(102, 5)),
+                                    new StreamRecord<>(IterationRecord.newRecord(103, 6))),
+                            /* One feedback record before the checkpoint and two after it */
+                            Collections.singletonList(new SubtaskAlignedEvent(5, 3, false)),
+                            new HashMap<Integer, Long>() {
+                                {
+                                    this.put(5, 3L);
+                                    this.put(6, 1L);
+                                }
+                            },
+                            5,
+                            4);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testCheckpointBeforeTerminated() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(4, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(4, false));
+                            harness.processAll();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(5, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(5, true));
+                            harness.processAll();
+
+                            putFeedbackRecords(iterationId, IterationRecord.newBarrier(2), null);
+                            harness.processAll();
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.RUNNING,
+                            Collections.emptyList(),
+                            Collections.singletonList(new SubtaskAlignedEvent(5, 0, false)),
+                            Collections.emptyMap(),
+                            5,
+                            4);
+                    return null;
+                });
+    }
+
+    @Test
+    public void testCheckpointAfterTerminating() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        TaskStateSnapshot taskStateSnapshot =
+                createHarnessAndRun(
+                        iterationId,
+                        operatorId,
+                        null,
+                        harness -> {
+                            harness.getTaskStateManager().getWaitForReportLatch().reset();
+
+                            putFeedbackRecords(
+                                    iterationId,
+                                    IterationRecord.newEpochWatermark(5, "tail"),
+                                    null);
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new GloballyAlignedEvent(5, true));
+                            harness.processAll();
+
+                            dispatchOperatorEvent(
+                                    harness, operatorId, new CoordinatorCheckpointEvent(2));
+                            harness.getStreamTask()
+                                    .triggerCheckpointAsync(
+                                            new CheckpointMetaData(2, 1000),
+                                            CheckpointOptions.alignedNoTimeout(
+                                                    CheckpointType.CHECKPOINT,
+                                                    CheckpointStorageLocationReference
+                                                            .getDefault()));
+                            harness.processAll();
+
+                            harness.getTaskStateManager().getWaitForReportLatch().await();
+                            return harness.getTaskStateManager()
+                                    .getLastJobManagerTaskStateSnapshot();
+                        });
+        assertNotNull(taskStateSnapshot);
+        cleanupFeedbackChannel(iterationId);
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                taskStateSnapshot,
+                harness -> {
+                    checkRestoredOperatorState(
+                            harness,
+                            HeadOperator.HeadOperatorStatus.TERMINATING,
+                            Collections.emptyList(),
+                            Collections.emptyList(),
+                            Collections.emptyMap(),
+                            -1,
+                            -1);
+
+                    putFeedbackRecords(
+                            iterationId,
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE + 1, "tail"),
+                            null);
+                    harness.processEvent(EndOfData.INSTANCE);
+                    harness.finishProcessing();
+
+                    return null;
+                });
+    }
+
+    @Test(timeout = 20000)
+    public void testTailAbortPendingCheckpointIfHeadBlocked() throws Exception {
+        IterationID iterationId = new IterationID();
+        OperatorID operatorId = new OperatorID();
+
+        createHarnessAndRun(
+                iterationId,
+                operatorId,
+                null,
+                harness -> {
+                    harness.processElement(new StreamRecord<>(IterationRecord.newRecord(100, 0)));
+                    dispatchOperatorEvent(harness, operatorId, new CoordinatorCheckpointEvent(2));
+                    harness.getStreamTask()
+                            .triggerCheckpointAsync(
+                                    new CheckpointMetaData(2, 1000),
+                                    CheckpointOptions.alignedNoTimeout(
+                                            CheckpointType.CHECKPOINT,
+                                            CheckpointStorageLocationReference.getDefault()));
+                    harness.processAll();
+
+                    putFeedbackRecords(iterationId, IterationRecord.newRecord(100, 1), null);
+                    harness.processAll();
+
+                    // Simulates the tail operator helping to abort the checkpoint
+                    CompletableFuture<Void> supplier =
+                            CompletableFuture.supplyAsync(
+                                    () -> {
+                                        try {
+                                            // Slightly postpone the execution until the head
+                                            // operator gets blocked.
+                                            Thread.sleep(2000);
+
+                                            OneInputStreamOperatorTestHarness<
+                                                            IterationRecord<?>, Void>
+                                                    testHarness =
+                                                            new OneInputStreamOperatorTestHarness<>(
+                                                                    new TailOperator(
+                                                                            iterationId, 0));
+                                            testHarness.open();
+
+                                            testHarness.getOperator().notifyCheckpointAborted(2);
+                                        } catch (Exception e) {
+                                            throw new CompletionException(e);
+                                        }
+
+                                        return null;
+                                    });
+
+                    harness.getStreamTask().notifyCheckpointAbortAsync(2, 0);
+                    harness.processAll();
+
+                    supplier.get();
+
+                    return null;
+                });
+    }
+
     private <T> T createHarnessAndRun(
             IterationID iterationId,
             OperatorID operatorId,
@@ -371,6 +825,46 @@ public class HeadOperatorTest extends TestLogger {
                         : new StreamRecord<>(record, timestamp));
     }
 
+    private static void checkRestoredOperatorState(
+            StreamTaskMailboxTestHarness<?> harness,
+            HeadOperator.HeadOperatorStatus expectedStatus,
+            List<Object> expectedOutput,
+            List<OperatorEvent> expectedOperatorEvents,
+            Map<Integer, Long> expectedNumFeedbackRecords,
+            int expectedLastAligned,
+            int expectedLastGloballyAligned) {
+        HeadOperator headOperator = RecordingHeadOperatorFactory.latestHeadOperator;
+        assertEquals(expectedStatus, headOperator.getStatus());
+        assertEquals(expectedOutput, new ArrayList<>(harness.getOutput()));
+        RecordingOperatorEventGateway eventGateway =
+                (RecordingOperatorEventGateway) headOperator.getOperatorEventGateway();
+        assertEquals(expectedOperatorEvents, new ArrayList<>(eventGateway.operatorEvents));
+
+        if (expectedStatus == HeadOperator.HeadOperatorStatus.RUNNING) {
+            RegularHeadOperatorRecordProcessor recordProcessor =
+                    (RegularHeadOperatorRecordProcessor) headOperator.getRecordProcessor();
+            assertEquals(
+                    expectedNumFeedbackRecords, recordProcessor.getNumFeedbackRecordsPerEpoch());
+            assertEquals(expectedLastAligned, recordProcessor.getLatestRoundAligned());
+            assertEquals(
+                    expectedLastGloballyAligned, recordProcessor.getLatestRoundGloballyAligned());
+        }
+    }
+
+    /**
+     * We have to manually clean up the feedback channel because we are not able to set the
+     * attempt number.
+     */
+    private static void cleanupFeedbackChannel(IterationID iterationId) {
+        FeedbackChannel<StreamRecord<IterationRecord<?>>> feedbackChannel =
+                FeedbackChannelBroker.get()
+                        .getChannel(
+                                OperatorUtils.<StreamRecord<IterationRecord<?>>>createFeedbackKey(
+                                                iterationId, 0)
+                                        .withSubTaskIndex(0, 0));
+        feedbackChannel.close();
+    }
+
     private static class RecordingOperatorEventGateway implements OperatorEventGateway {
 
         final BlockingQueue<OperatorEvent> operatorEvents = new LinkedBlockingQueue<>();