You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/03 13:57:45 UTC

[flink-ml] 07/08: [FLINK-24655][iteration] Do not rely on the precedent tasks to insert epoch watermark

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 9172ec762c797a697e71c17b401a94c3222dc506
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Thu Oct 7 10:47:30 2021 +0800

    [FLINK-24655][iteration] Do not rely on the precedent tasks to insert epoch watermark
    
    If the preceding tasks have already finished and a failover then occurs,
    those finished tasks are skipped on restart and would not insert the
    epoch watermark again.
---
 .../org/apache/flink/iteration/Iterations.java     | 29 +++++++++---------
 .../operator/AbstractWrapperOperator.java          | 10 ++++++-
 .../flink/iteration/operator/InputOperator.java    | 20 ++-----------
 .../flink/iteration/operator/ReplayOperator.java   | 28 +++++++++++++++---
 .../allround/AbstractAllRoundWrapperOperator.java  | 21 +++++++------
 .../MultipleInputAllRoundWrapperOperator.java      |  2 ++
 .../allround/TwoInputAllRoundWrapperOperator.java  |  2 ++
 .../OperatorEpochWatermarkTracker.java             | 24 +++++++++++++--
 .../flink/iteration/IterationConstructionTest.java |  2 +-
 .../iteration/operator/InputOperatorTest.java      | 34 +---------------------
 .../iteration/operator/ReplayOperatorTest.java     | 21 +++++++------
 .../OperatorEpochWatermarkTrackerTest.java         | 17 +++++++++++
 12 files changed, 113 insertions(+), 97 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
index b334da2..9c9fa42 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/Iterations.java
@@ -35,7 +35,7 @@ import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.transformations.OneInputTransformation;
 import org.apache.flink.util.Collector;
 
@@ -197,7 +197,7 @@ public class Iterations {
                         .stream()
                         .mapToInt(i -> i)
                         .sum();
-        DataStreamList initVariableInputs = addInputs(initVariableStreams, false);
+        DataStreamList initVariableInputs = addInputs(initVariableStreams);
         DataStreamList headStreams =
                 addHeads(
                         initVariableStreams,
@@ -207,7 +207,7 @@ public class Iterations {
                         false,
                         0);
 
-        DataStreamList dataStreamInputs = addInputs(dataStreams, true);
+        DataStreamList dataStreamInputs = addInputs(dataStreams);
         if (replayedDataStreamIndices.size() > 0) {
             dataStreamInputs =
                     addReplayer(
@@ -293,13 +293,13 @@ public class Iterations {
             // Notes that the HeadOperator would broadcast the globally aligned events,
             // thus the operator does not require emit to the sideoutput specially.
             DataStream<?> replayedInput =
-                    ((SingleOutputStreamOperator<IterationRecord<?>>) firstHeadStream)
-                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
-                            .map(x -> x, dataStreamInputs.get(i).getType())
-                            .setParallelism(firstHeadStream.getParallelism())
-                            .name("signal-change-typeinfo")
-                            .broadcast()
-                            .union(dataStreamInputs.get(i))
+                    dataStreamInputs
+                            .get(i)
+                            .connect(
+                                    ((SingleOutputStreamOperator<IterationRecord<?>>)
+                                                    firstHeadStream)
+                                            .getSideOutput(HeadOperator.ALIGN_NOTIFY_OUTPUT_TAG)
+                                            .broadcast())
                             .transform(
                                     "Replayer-"
                                             + originalDataStreams
@@ -307,7 +307,7 @@ public class Iterations {
                                                     .getTransformation()
                                                     .getName(),
                                     dataStreamInputs.get(i).getType(),
-                                    (OneInputStreamOperator) new ReplayOperator<>())
+                                    (TwoInputStreamOperator) new ReplayOperator<>())
                             .setParallelism(dataStreamInputs.get(i).getParallelism());
             result.add(replayedInput);
         }
@@ -338,7 +338,7 @@ public class Iterations {
                         .name(terminationCriteria.getTransformation().getName())
                         .setParallelism(terminationCriteria.getParallelism());
         DataStreamList criteriaSources = DataStreamList.of(emptyCriteriaSource);
-        DataStreamList criteriaInputs = addInputs(criteriaSources, false);
+        DataStreamList criteriaInputs = addInputs(criteriaSources);
         DataStreamList criteriaHeaders =
                 addHeads(
                         criteriaSources,
@@ -398,8 +398,7 @@ public class Iterations {
         return map(dataStreams, DataStream::getType);
     }
 
-    private static DataStreamList addInputs(
-            DataStreamList dataStreams, boolean insertMaxEpochWatermark) {
+    private static DataStreamList addInputs(DataStreamList dataStreams) {
         return new DataStreamList(
                 map(
                         dataStreams,
@@ -408,7 +407,7 @@ public class Iterations {
                                         .transform(
                                                 "input-" + dataStream.getTransformation().getName(),
                                                 new IterationRecordTypeInfo<>(dataStream.getType()),
-                                                new InputOperator(insertMaxEpochWatermark))
+                                                new InputOperator())
                                         .setParallelism(dataStream.getParallelism())));
     }
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
index 3aebce9..80a6682 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.metrics.groups.InternalOperatorMetricGroup;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
@@ -47,7 +48,9 @@ import static org.apache.flink.util.Preconditions.checkState;
 
 /** The base class of all the wrapper operators. It provides the alignment functionality. */
 public abstract class AbstractWrapperOperator<T>
-        implements StreamOperator<IterationRecord<T>>, OperatorEpochWatermarkTrackerListener {
+        implements StreamOperator<IterationRecord<T>>,
+                OperatorEpochWatermarkTrackerListener,
+                BoundedMultiInput {
 
     private static final Logger LOG = LoggerFactory.getLogger(AbstractWrapperOperator.class);
 
@@ -130,6 +133,11 @@ public abstract class AbstractWrapperOperator<T>
         epochWatermarkSupplier.set(null);
     }
 
+    @Override
+    public void endInput(int i) throws Exception {
+        epochWatermarkTracker.finish(i - 1);
+    }
+
     private InternalOperatorMetricGroup createOperatorMetricGroup(
             Environment environment, StreamConfig streamConfig) {
         try {
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java
index 5ef167e..b6908ff 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/InputOperator.java
@@ -20,21 +20,17 @@ package org.apache.flink.iteration.operator;
 
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.ChainingStrategy;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 
 /** Input operator that wraps the user record into {@link IterationRecord}. */
 public class InputOperator<T> extends AbstractStreamOperator<IterationRecord<T>>
-        implements OneInputStreamOperator<T, IterationRecord<T>>, BoundedOneInput {
-
-    private final boolean insertMaxEpochWatermark;
+        implements OneInputStreamOperator<T, IterationRecord<T>> {
 
     private transient StreamRecord<IterationRecord<T>> reusable;
 
-    public InputOperator(boolean insertMaxEpochWatermark) {
-        this.insertMaxEpochWatermark = insertMaxEpochWatermark;
+    public InputOperator() {
         this.chainingStrategy = ChainingStrategy.ALWAYS;
     }
 
@@ -50,16 +46,4 @@ public class InputOperator<T> extends AbstractStreamOperator<IterationRecord<T>>
         reusable.getValue().setValue(streamRecord.getValue());
         output.collect(reusable);
     }
-
-    @Override
-    public void endInput() throws Exception {
-        if (insertMaxEpochWatermark) {
-            reusable.replace(
-                    IterationRecord.newEpochWatermark(
-                            Integer.MAX_VALUE,
-                            OperatorUtils.getUniqueSenderId(
-                                    getOperatorID(), getRuntimeContext().getIndexOfThisSubtask())));
-            output.collect(reusable);
-        }
-    }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
index 2deb01e..16c6653 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/ReplayOperator.java
@@ -31,8 +31,9 @@ import org.apache.flink.iteration.typeinfo.IterationRecordSerializer;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
-import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.util.ExceptionUtils;
@@ -46,8 +47,10 @@ import static org.apache.flink.util.Preconditions.checkState;
 
 /** Replays the data received in the round 0 in the following round. */
 public class ReplayOperator<T> extends AbstractStreamOperator<IterationRecord<T>>
-        implements OneInputStreamOperator<IterationRecord<T>, IterationRecord<T>>,
-                OperatorEpochWatermarkTrackerListener {
+        implements TwoInputStreamOperator<
+                        IterationRecord<T>, IterationRecord<Void>, IterationRecord<T>>,
+                OperatorEpochWatermarkTrackerListener,
+                BoundedMultiInput {
 
     private OperatorEpochWatermarkTracker progressTracker;
 
@@ -115,7 +118,7 @@ public class ReplayOperator<T> extends AbstractStreamOperator<IterationRecord<T>
     }
 
     @Override
-    public void processElement(StreamRecord<IterationRecord<T>> element) throws Exception {
+    public void processElement1(StreamRecord<IterationRecord<T>> element) throws Exception {
         switch (element.getValue().getType()) {
             case RECORD:
                 dataCacheWriter.addRecord(element.getValue().getValue());
@@ -132,6 +135,23 @@ public class ReplayOperator<T> extends AbstractStreamOperator<IterationRecord<T>
     }
 
     @Override
+    public void processElement2(StreamRecord<IterationRecord<Void>> element) throws Exception {
+        if (element.getValue().getType() == IterationRecord.Type.EPOCH_WATERMARK) {
+            progressTracker.onEpochWatermark(
+                    1, element.getValue().getSender(), element.getValue().getEpoch());
+        } else {
+            throw new UnsupportedOperationException(
+                    "Not supported element type: " + element.getValue());
+        }
+    }
+
+    @Override
+    public void endInput(int i) throws Exception {
+        // The notification ranges from 1 to N while the track uses 0 to N -1.
+        progressTracker.finish(i - 1);
+    }
+
+    @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
         if (epochWatermark == 0) {
             // No need to replay for the round 0, it is output directly.
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index d3461a1..0ea742c 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -99,19 +99,18 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
-        if (epochWatermark <= latestEpochWatermark) {
-            return;
+        if (epochWatermark > latestEpochWatermark) {
+            latestEpochWatermark = epochWatermark;
+
+            setIterationContextRound(epochWatermark);
+            processOperatorOrUdfIfSatisfy(
+                    wrappedOperator,
+                    IterationListener.class,
+                    listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
+            clearIterationContextRound();
         }
-        latestEpochWatermark = epochWatermark;
 
-        setIterationContextRound(epochWatermark);
-        processOperatorOrUdfIfSatisfy(
-                wrappedOperator,
-                IterationListener.class,
-                listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
-        clearIterationContextRound();
-
-        // Broadcast the events.
+        // Always broadcasts the events.
         super.onEpochWatermarkIncrement(epochWatermark);
     }
 
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
index c975538..4d94298 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperator.java
@@ -80,6 +80,8 @@ public class MultipleInputAllRoundWrapperOperator<OUT>
 
     @Override
     public void endInput(int i) throws Exception {
+        super.endInput(i);
+
         if (wrappedOperator instanceof BoundedMultiInput) {
             ((BoundedMultiInput) wrappedOperator).endInput(i);
         }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
index 2c633c1..bedcccf 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperator.java
@@ -113,6 +113,8 @@ public class TwoInputAllRoundWrapperOperator<IN1, IN2, OUT>
 
     @Override
     public void endInput(int i) throws Exception {
+        super.endInput(i);
+
         if (wrappedOperator instanceof BoundedMultiInput) {
             ((BoundedMultiInput) wrappedOperator).endInput(i);
         }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
index e2d716b..33ddee8 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTracker.java
@@ -60,9 +60,21 @@ public class OperatorEpochWatermarkTracker {
         InputStatus inputStatus = inputStatuses.get(inputIndex);
         inputStatus.onUpdate(sender, epochWatermark);
 
-        if (inputStatus.getInputLowerBound() > allInputsLowerBound.getValue(inputIndex)) {
+        tryUpdateLowerBound(inputIndex);
+    }
+
+    public void finish(int inputIndex) throws IOException {
+        inputStatuses.get(inputIndex).finish();
+
+        tryUpdateLowerBound(inputIndex);
+    }
+
+    private void tryUpdateLowerBound(int changedInputIndex) throws IOException {
+        if (inputStatuses.get(changedInputIndex).getInputLowerBound()
+                > allInputsLowerBound.getValue(changedInputIndex)) {
             int oldLowerBound = allInputsLowerBound.getLowerBound();
-            allInputsLowerBound.updateValue(inputIndex, inputStatus.getInputLowerBound());
+            allInputsLowerBound.updateValue(
+                    changedInputIndex, inputStatuses.get(changedInputIndex).getInputLowerBound());
             if (allInputsLowerBound.getLowerBound() > oldLowerBound) {
                 progressTrackerListener.onEpochWatermarkIncrement(
                         allInputsLowerBound.getLowerBound());
@@ -95,6 +107,12 @@ public class OperatorEpochWatermarkTracker {
             allChannelsLowerBound.updateValue(index, epochWatermark);
         }
 
+        public void finish() {
+            for (int i = 0; i < numberOfChannels; ++i) {
+                allChannelsLowerBound.updateValue(i, Integer.MAX_VALUE);
+            }
+        }
+
         public int getInputLowerBound() {
             return allChannelsLowerBound.getLowerBound();
         }
@@ -122,7 +140,7 @@ public class OperatorEpochWatermarkTracker {
 
         public void updateValue(int channel, int value) {
             checkState(
-                    value > values[channel],
+                    value >= values[channel],
                     String.format(
                             "The channel %d received an outdated value %d, which currently is %d",
                             channel, value, values[channel]));
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
index 3cd4bf2..f2ec465 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/IterationConstructionTest.java
@@ -378,7 +378,7 @@ public class IterationConstructionTest extends TestLogger {
                         /* 0 */ "Source: Variable -> input-Variable",
                         /* 1 */ "Source: Constant -> input-Constant",
                         /* 2 */ "Source: Termination -> input-Termination",
-                        /* 3 */ "head-Variable -> signal-change-typeinfo",
+                        /* 3 */ "head-Variable",
                         /* 4 */ "Replayer-Constant",
                         /* 5 */ "Processor -> output-SideOutput -> Sink: Sink",
                         /* 6 */ "Feedback",
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java
index 03cba7a..bba3fec 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/InputOperatorTest.java
@@ -34,7 +34,7 @@ public class InputOperatorTest extends TestLogger {
     @Test
     public void testWrapRecord() throws Exception {
         OneInputStreamOperatorTestHarness<Integer, IterationRecord<Integer>> testHarness =
-                new OneInputStreamOperatorTestHarness<>(new InputOperator<>(false));
+                new OneInputStreamOperatorTestHarness<>(new InputOperator<>());
         testHarness.open();
 
         ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
@@ -46,36 +46,4 @@ public class InputOperatorTest extends TestLogger {
         TestHarnessUtil.assertOutputEquals(
                 "Output was not correct", expectedOutput, testHarness.getOutput());
     }
-
-    @Test
-    public void testInsertMaxEpochWatermarkIfSpecified() throws Exception {
-        OneInputStreamOperatorTestHarness<Integer, IterationRecord<Integer>> testHarness =
-                new OneInputStreamOperatorTestHarness<>(new InputOperator<>(true));
-        testHarness.open();
-
-        testHarness.endInput();
-
-        ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
-        expectedOutput.add(
-                new StreamRecord<>(
-                        IterationRecord.newEpochWatermark(
-                                Integer.MAX_VALUE,
-                                OperatorUtils.getUniqueSenderId(
-                                        testHarness.getOperator().getOperatorID(), 0))));
-
-        TestHarnessUtil.assertOutputEquals(
-                "Output was not correct", expectedOutput, testHarness.getOutput());
-    }
-
-    @Test
-    public void testNotInsertMaxEpochWatermarkIfSpecified() throws Exception {
-        OneInputStreamOperatorTestHarness<Integer, IterationRecord<Integer>> testHarness =
-                new OneInputStreamOperatorTestHarness<>(new InputOperator<>(false));
-        testHarness.open();
-
-        testHarness.endInput();
-
-        TestHarnessUtil.assertOutputEquals(
-                "Output was not correct", new ConcurrentLinkedQueue<>(), testHarness.getOutput());
-    }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
index d17c78f..97d61ca 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/ReplayOperatorTest.java
@@ -24,9 +24,9 @@ import org.apache.flink.iteration.config.IterationOptions;
 import org.apache.flink.iteration.typeinfo.IterationRecordTypeInfo;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
 
 import org.junit.Rule;
 import org.junit.Test;
@@ -52,9 +52,10 @@ public class ReplayOperatorTest {
 
         try (StreamTaskMailboxTestHarness<IterationRecord<Integer>> harness =
                 new StreamTaskMailboxTestHarnessBuilder<>(
-                                OneInputStreamTask::new,
+                                TwoInputStreamTask::new,
                                 new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO))
-                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO), 2)
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.INT_TYPE_INFO), 1)
+                        .addInput(new IterationRecordTypeInfo<>(BasicTypeInfo.VOID_TYPE_INFO), 1)
                         .setupOutputForSingletonOperatorChain(new ReplayOperator<>(), operatorId)
                         .buildUnrestored()) {
             harness.getStreamTask()
@@ -70,19 +71,16 @@ public class ReplayOperatorTest {
             for (int i = 0; i < numRecords; ++i) {
                 harness.processElement(new StreamRecord<>(IterationRecord.newRecord(i, 0)), 0, 0);
             }
+            harness.endInput(0, true);
             harness.processElement(
-                    new StreamRecord<>(
-                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "sender0")),
-                    0,
-                    0);
-            harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, "sender1")), 0, 1);
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(0, "sender1")), 1, 0);
             assertOutputAllRecordsAndEpochWatermark(harness.getOutput(), numRecords, operatorId, 0);
             harness.getOutput().clear();
 
             // The round 1
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, "sender1")), 0, 1);
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(1, "sender1")), 1, 0);
+            // The output would be done asynchronously inside the ReplayerOperator.
             while (harness.getOutput().size() < numRecords + 1) {
                 Thread.sleep(500);
             }
@@ -91,7 +89,8 @@ public class ReplayOperatorTest {
 
             // The round 2
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "sender1")), 0, 1);
+                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "sender1")), 1, 0);
+            // The output would be done asynchronously inside the ReplayerOperator.
             while (harness.getOutput().size() < numRecords + 1) {
                 Thread.sleep(500);
             }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java
index 73a95b1..4935484 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/progresstrack/OperatorEpochWatermarkTrackerTest.java
@@ -59,6 +59,23 @@ public class OperatorEpochWatermarkTrackerTest extends TestLogger {
         assertEquals(Collections.singletonList(3), recordingProgressListener.notifications);
     }
 
+    @Test
+    public void testFinish() throws IOException {
+        RecordingProgressListener recordingProgressListener = new RecordingProgressListener();
+        int[] numberOfChannels = new int[] {2, 3};
+        OperatorEpochWatermarkTracker progressTracker =
+                new OperatorEpochWatermarkTracker(numberOfChannels, recordingProgressListener);
+        progressTracker.finish(0);
+
+        testOnEpochWatermark(
+                new int[] {0, 0, 1},
+                progressTracker,
+                recordingProgressListener,
+                new int[] {1, 1, 1},
+                new String[] {"1-1", "1-2", "1-3"},
+                3);
+    }
+
     private void testOnEpochWatermark(
             int[] expectedNumNotifications,
             OperatorEpochWatermarkTracker tracker,