You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by ga...@apache.org on 2021/11/15 03:16:03 UTC

[flink-ml] branch master updated (acbf4b9 -> 5407627)

This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git.


    from acbf4b9  [FLINK-24807][iteration] Not start logging at the head operator if the barrier feed back first
     add efb9c14  [hotfix][iteration] Fix the bad compile due to changes of state
     new 5407627  [FLINK-24808][iteration] Support IterationListener for per-round operators

The 1 revision listed above as "new" is entirely new to this
repository and will be described in a separate email.  The revision
listed as "add" was already present in the repository and has only
been added to this reference.


Summary of changes:
 .../operator/AbstractWrapperOperator.java          | 28 ++++++++
 .../allround/AbstractAllRoundWrapperOperator.java  | 27 --------
 .../perround/AbstractPerRoundWrapperOperator.java  | 49 +++++++++-----
 .../MultipleInputPerRoundWrapperOperator.java      | 78 ++++++++++++----------
 .../perround/OneInputPerRoundWrapperOperator.java  |  3 +-
 .../perround/TwoInputPerRoundWrapperOperator.java  |  3 +-
 .../iteration/operator/allround/LifeCycle.java     |  2 +
 .../MultipleInputAllRoundWrapperOperatorTest.java  | 37 +++++++++-
 .../OneInputAllRoundWrapperOperatorTest.java       | 28 +++++++-
 .../TwoInputAllRoundWrapperOperatorTest.java       | 33 ++++++++-
 .../MultipleInputPerRoundWrapperOperatorTest.java  | 34 ++++++++--
 .../OneInputPerRoundWrapperOperatorTest.java       | 25 ++++++-
 .../TwoInputPerRoundWrapperOperatorTest.java       | 30 +++++++--
 .../operator/AbstractBroadcastWrapperOperator.java |  6 +-
 pom.xml                                            |  1 +
 15 files changed, 284 insertions(+), 100 deletions(-)

[flink-ml] 01/01: [FLINK-24808][iteration] Support IterationListener for per-round operators

Posted by ga...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git

commit 540762700ce7cfe502a62a81505cc133b9616031
Author: Yun Gao <ga...@gmail.com>
AuthorDate: Sat Nov 6 14:27:58 2021 +0800

    [FLINK-24808][iteration] Support IterationListener for per-round operators
    
    This closes #26.
---
 .../operator/AbstractWrapperOperator.java          | 28 ++++++++
 .../allround/AbstractAllRoundWrapperOperator.java  | 27 --------
 .../perround/AbstractPerRoundWrapperOperator.java  | 49 +++++++++-----
 .../MultipleInputPerRoundWrapperOperator.java      | 78 ++++++++++++----------
 .../perround/OneInputPerRoundWrapperOperator.java  |  3 +-
 .../perround/TwoInputPerRoundWrapperOperator.java  |  3 +-
 .../iteration/operator/allround/LifeCycle.java     |  2 +
 .../MultipleInputAllRoundWrapperOperatorTest.java  | 37 +++++++++-
 .../OneInputAllRoundWrapperOperatorTest.java       | 28 +++++++-
 .../TwoInputAllRoundWrapperOperatorTest.java       | 33 ++++++++-
 .../MultipleInputPerRoundWrapperOperatorTest.java  | 34 ++++++++--
 .../OneInputPerRoundWrapperOperatorTest.java       | 25 ++++++-
 .../TwoInputPerRoundWrapperOperatorTest.java       | 30 +++++++--
 13 files changed, 278 insertions(+), 99 deletions(-)

diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
index 80a6682..86dbc2f 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/AbstractWrapperOperator.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.iteration.operator;
 
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.broadcast.BroadcastOutput;
 import org.apache.flink.iteration.broadcast.BroadcastOutputFactory;
@@ -34,8 +35,10 @@ import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.OutputTag;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -64,6 +67,8 @@ public abstract class AbstractWrapperOperator<T>
 
     protected final StreamOperatorFactory<T> operatorFactory;
 
+    protected final IterationContext iterationContext;
+
     // --------------- proxy ---------------------------
 
     protected final ProxyOutput<T> proxyOutput;
@@ -105,6 +110,7 @@ public abstract class AbstractWrapperOperator<T>
         this.eventBroadcastOutput =
                 BroadcastOutputFactory.createBroadcastOutput(
                         output, metrics.getIOMetricGroup().getNumRecordsOutCounter());
+        this.iterationContext = new IterationContext();
     }
 
     protected void onEpochWatermarkEvent(int inputIndex, IterationRecord<?> iterationRecord)
@@ -116,6 +122,20 @@ public abstract class AbstractWrapperOperator<T>
                 inputIndex, iterationRecord.getSender(), iterationRecord.getEpoch());
     }
 
+    @SuppressWarnings({"unchecked", "rawtypes"})
+    protected void notifyEpochWatermarkIncrement(
+            IterationListener<?> listener, int epochWatermark) {
+        if (epochWatermark != Integer.MAX_VALUE) {
+            listener.onEpochWatermarkIncremented(
+                    epochWatermark,
+                    iterationContext,
+                    new TimestampedCollector<>((Output) proxyOutput));
+        } else {
+            listener.onIterationTerminated(
+                    iterationContext, new TimestampedCollector<>((Output) proxyOutput));
+        }
+    }
+
     @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
         eventBroadcastOutput.broadcastEmit(
@@ -156,6 +176,14 @@ public abstract class AbstractWrapperOperator<T>
         }
     }
 
+    private class IterationContext implements IterationListener.Context {
+
+        @Override
+        public <X> void output(OutputTag<X> outputTag, X value) {
+            proxyOutput.collect(outputTag, new StreamRecord<>(value));
+        }
+    }
+
     private static class EpochSupplier implements Supplier<Integer> {
 
         private Integer epoch;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
index 180477c..1c2d0b5 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/allround/AbstractAllRoundWrapperOperator.java
@@ -37,18 +37,15 @@ import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.streaming.api.operators.KeyContext;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
-import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactory;
 import org.apache.flink.streaming.api.operators.StreamOperatorFactoryUtil;
 import org.apache.flink.streaming.api.operators.StreamOperatorParameters;
 import org.apache.flink.streaming.api.operators.StreamOperatorStateContext;
 import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
-import org.apache.flink.streaming.api.operators.TimestampedCollector;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.ProcessingTimeService;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
-import org.apache.flink.util.OutputTag;
 
 import javax.annotation.Nonnull;
 import javax.annotation.Nullable;
@@ -65,8 +62,6 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
 
     protected final S wrappedOperator;
 
-    private final IterationContext iterationContext;
-
     // --------------- state ---------------------------
     private int latestEpochWatermark = -1;
 
@@ -96,7 +91,6 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
                 EpochAware.class,
                 epochWatermarkAware ->
                         epochWatermarkAware.setEpochSupplier(epochWatermarkSupplier));
-        this.iterationContext = new IterationContext();
     }
 
     @Override
@@ -116,19 +110,6 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
         super.onEpochWatermarkIncrement(epochWatermark);
     }
 
-    @SuppressWarnings({"unchecked", "rawtypes"})
-    private void notifyEpochWatermarkIncrement(IterationListener<?> listener, int epochWatermark) {
-        if (epochWatermark != Integer.MAX_VALUE) {
-            listener.onEpochWatermarkIncremented(
-                    epochWatermark,
-                    iterationContext,
-                    new TimestampedCollector<>((Output) proxyOutput));
-        } else {
-            listener.onIterationTerminated(
-                    iterationContext, new TimestampedCollector<>((Output) proxyOutput));
-        }
-    }
-
     @Override
     public void initializeState(StreamTaskStateInitializer streamTaskStateManager)
             throws Exception {
@@ -257,14 +238,6 @@ public abstract class AbstractAllRoundWrapperOperator<T, S extends StreamOperato
         return latestEpochWatermark;
     }
 
-    private class IterationContext implements IterationListener.Context {
-
-        @Override
-        public <X> void output(OutputTag<X> outputTag, X value) {
-            proxyOutput.collect(outputTag, new StreamRecord<>(value));
-        }
-    }
-
     private static class RecordingStreamTaskStateInitializer implements StreamTaskStateInitializer {
 
         private final StreamTaskStateInitializer wrapped;
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
index 4762c74..bdc1d26 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/AbstractPerRoundWrapperOperator.java
@@ -28,6 +28,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.MetricOptions;
 import org.apache.flink.contrib.streaming.state.RocksDBKeyedStateBackend;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.AbstractWrapperOperator;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
@@ -178,37 +179,52 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
         return wrappedOperator;
     }
 
-    protected abstract void endInputAndEmitMaxWatermark(S operator, int round) throws Exception;
+    protected abstract void endInputAndEmitMaxWatermark(S operator, int epoch, int epochWatermark)
+            throws Exception;
 
-    private void closeStreamOperator(S operator, int round) throws Exception {
-        setIterationContextRound(round);
-        endInputAndEmitMaxWatermark(operator, round);
+    protected void closeStreamOperator(S operator, int epoch, int epochWatermark) throws Exception {
+        setIterationContextRound(epoch);
+        OperatorUtils.processOperatorOrUdfIfSatisfy(
+                operator,
+                IterationListener.class,
+                listener -> notifyEpochWatermarkIncrement(listener, epochWatermark));
+        endInputAndEmitMaxWatermark(operator, epoch, epochWatermark);
         operator.finish();
         operator.close();
         setIterationContextRound(null);
 
         // Cleanup the states used by this operator.
-        cleanupOperatorStates(round);
+        cleanupOperatorStates(epoch);
 
         if (stateHandler.getKeyedStateBackend() != null) {
-            cleanupKeyedStates(round);
+            cleanupKeyedStates(epoch);
         }
     }
 
     @Override
     public void onEpochWatermarkIncrement(int epochWatermark) throws IOException {
+        checkState(epochWatermark >= 0, "The epoch watermark should be non-negative.");
         if (epochWatermark > latestEpochWatermark) {
             latestEpochWatermark = epochWatermark;
 
             // Destroys all the operators with round < epoch watermark. Notes that
-            // the onEpochWatermarkIncrement must be from 0 and increment by 1 each time.
-            if (wrappedOperators.containsKey(epochWatermark)) {
-                try {
-                    closeStreamOperator(wrappedOperators.get(epochWatermark), epochWatermark);
-                } catch (Exception exception) {
-                    ExceptionUtils.rethrow(exception);
+            // the onEpochWatermarkIncrement must be from 0 and increment by 1 each time, except
+            // for the last round.
+            try {
+                if (epochWatermark < Integer.MAX_VALUE) {
+                    S wrappedOperator = wrappedOperators.remove(epochWatermark);
+                    if (wrappedOperator != null) {
+                        closeStreamOperator(wrappedOperator, epochWatermark, epochWatermark);
+                    }
+                } else {
+                    List<Integer> sortedEpochs = new ArrayList<>(wrappedOperators.keySet());
+                    Collections.sort(sortedEpochs);
+                    for (Integer epoch : sortedEpochs) {
+                        closeStreamOperator(wrappedOperators.remove(epoch), epoch, epochWatermark);
+                    }
                 }
-                wrappedOperators.remove(epochWatermark);
+            } catch (Exception exception) {
+                ExceptionUtils.rethrow(exception);
             }
         }
 
@@ -340,10 +356,9 @@ public abstract class AbstractPerRoundWrapperOperator<T, S extends StreamOperato
 
     @Override
     public void finish() throws Exception {
-        for (Map.Entry<Integer, S> entry : wrappedOperators.entrySet()) {
-            closeStreamOperator(entry.getValue(), entry.getKey());
-        }
-        wrappedOperators.clear();
+        checkState(
+                wrappedOperators.size() == 0,
+                "Some wrapped operators are still not closed yet: " + wrappedOperators.keySet());
     }
 
     @Override
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
index 7c6240f..c7ebbbb 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperator.java
@@ -66,7 +66,15 @@ public class MultipleInputPerRoundWrapperOperator<OUT>
     }
 
     @Override
-    protected void endInputAndEmitMaxWatermark(MultipleInputStreamOperator<OUT> operator, int round)
+    protected MultipleInputStreamOperator<OUT> getWrappedOperator(int epoch) {
+        MultipleInputStreamOperator<OUT> operator = super.getWrappedOperator(epoch);
+        operatorInputsByEpoch.put(epoch, operator.getInputs());
+        return operator;
+    }
+
+    @Override
+    protected void endInputAndEmitMaxWatermark(
+            MultipleInputStreamOperator<OUT> operator, int epoch, int epochWatermark)
             throws Exception {
         OperatorUtils.processOperatorOrUdfIfSatisfy(
                 operator,
@@ -78,29 +86,16 @@ public class MultipleInputPerRoundWrapperOperator<OUT>
                 });
 
         for (int i = 0; i < numberOfInputs; ++i) {
-            operatorInputsByEpoch.get(round).get(i).processWatermark(new Watermark(Long.MAX_VALUE));
+            operatorInputsByEpoch.get(epoch).get(i).processWatermark(new Watermark(Long.MAX_VALUE));
         }
     }
 
-    private <IN> void processElement(
-            int inputIndex,
-            Input<IN> input,
-            StreamRecord<IN> reusedInput,
-            StreamRecord<IterationRecord<IN>> element)
+    @Override
+    protected void closeStreamOperator(
+            MultipleInputStreamOperator<OUT> operator, int epoch, int epochWatermark)
             throws Exception {
-        switch (element.getValue().getType()) {
-            case RECORD:
-                reusedInput.replace(element.getValue().getValue(), element.getTimestamp());
-                setIterationContextRound(element.getValue().getEpoch());
-                input.processElement(reusedInput);
-                clearIterationContextRound();
-                break;
-            case EPOCH_WATERMARK:
-                onEpochWatermarkEvent(inputIndex, element.getValue());
-                break;
-            default:
-                throw new FlinkRuntimeException("Not supported iteration record type: " + element);
-        }
+        super.closeStreamOperator(operator, epoch, epochWatermark);
+        operatorInputsByEpoch.remove(epoch);
     }
 
     @Override
@@ -130,17 +125,25 @@ public class MultipleInputPerRoundWrapperOperator<OUT>
 
         @Override
         public void processElement(StreamRecord<IterationRecord<IN>> element) throws Exception {
-            if (!operatorInputsByEpoch.containsKey(element.getValue().getEpoch())) {
-                MultipleInputStreamOperator<OUT> operator =
-                        getWrappedOperator(element.getValue().getEpoch());
-                operatorInputsByEpoch.put(element.getValue().getEpoch(), operator.getInputs());
+            switch (element.getValue().getType()) {
+                case RECORD:
+                    // Ensures the operators are created.
+                    getWrappedOperator(element.getValue().getEpoch());
+                    reusedInput.replace(element.getValue().getValue(), element.getTimestamp());
+                    setIterationContextRound(element.getValue().getEpoch());
+                    operatorInputsByEpoch
+                            .get(element.getValue().getEpoch())
+                            .get(inputIndex)
+                            .processElement(reusedInput);
+                    clearIterationContextRound();
+                    break;
+                case EPOCH_WATERMARK:
+                    onEpochWatermarkEvent(inputIndex, element.getValue());
+                    break;
+                default:
+                    throw new FlinkRuntimeException(
+                            "Not supported iteration record type: " + element);
             }
-
-            MultipleInputPerRoundWrapperOperator.this.processElement(
-                    inputIndex,
-                    operatorInputsByEpoch.get(element.getValue().getEpoch()).get(inputIndex),
-                    reusedInput,
-                    element);
         }
 
         @Override
@@ -168,13 +171,18 @@ public class MultipleInputPerRoundWrapperOperator<OUT>
         }
 
         @Override
-        public void setKeyContextElement(StreamRecord<IterationRecord<IN>> record)
+        public void setKeyContextElement(StreamRecord<IterationRecord<IN>> element)
                 throws Exception {
-            MultipleInputStreamOperator<OUT> operator =
-                    getWrappedOperator(record.getValue().getEpoch());
 
-            reusedInput.replace(record.getValue(), record.getTimestamp());
-            operator.getInputs().get(inputIndex).setKeyContextElement(reusedInput);
+            if (element.getValue().getType() == IterationRecord.Type.RECORD) {
+                // Ensures the operators are created.
+                getWrappedOperator(element.getValue().getEpoch());
+                reusedInput.replace(element.getValue(), element.getTimestamp());
+                operatorInputsByEpoch
+                        .get(element.getValue().getEpoch())
+                        .get(inputIndex)
+                        .setKeyContextElement(reusedInput);
+            }
         }
     }
 }
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java
index ebfb3be..e3847de 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperator.java
@@ -51,7 +51,8 @@ public class OneInputPerRoundWrapperOperator<IN, OUT>
     }
 
     @Override
-    protected void endInputAndEmitMaxWatermark(OneInputStreamOperator<IN, OUT> operator, int round)
+    protected void endInputAndEmitMaxWatermark(
+            OneInputStreamOperator<IN, OUT> operator, int epoch, int epochWatermark)
             throws Exception {
         OperatorUtils.processOperatorOrUdfIfSatisfy(
                 operator, BoundedOneInput.class, BoundedOneInput::endInput);
diff --git a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java
index a8a8d75..ee409da 100644
--- a/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java
+++ b/flink-ml-iteration/src/main/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperator.java
@@ -52,7 +52,8 @@ public class TwoInputPerRoundWrapperOperator<IN1, IN2, OUT>
 
     @Override
     protected void endInputAndEmitMaxWatermark(
-            TwoInputStreamOperator<IN1, IN2, OUT> operator, int round) throws Exception {
+            TwoInputStreamOperator<IN1, IN2, OUT> operator, int epoch, int epochWatermark)
+            throws Exception {
         OperatorUtils.processOperatorOrUdfIfSatisfy(
                 operator,
                 BoundedMultiInput.class,
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
index 552c933..8229b04 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/LifeCycle.java
@@ -30,6 +30,8 @@ public enum LifeCycle {
     SNAPSHOT_STATE,
     NOTIFY_CHECKPOINT_COMPLETE,
     NOTIFY_CHECKPOINT_ABORT,
+    EPOCH_WATERMARK_INCREMENTED,
+    ITERATION_TERMINATION,
     END_INPUT,
     MAX_WATERMARK,
     FINISH,
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
index 9a2b72a..1eb5734 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/MultipleInputAllRoundWrapperOperatorTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.allround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
@@ -45,6 +46,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Test;
@@ -86,6 +88,18 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-1")), 1);
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-2")), 2);
+            harness.processElement(
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-0")),
+                    0);
+            harness.processElement(
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-1")),
+                    1);
+            harness.processElement(
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-2")),
+                    2);
 
             // Checks the output
             assertEquals(
@@ -95,7 +109,11 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
                             new StreamRecord<>(IterationRecord.newRecord(7, 3), 4),
                             new StreamRecord<>(
                                     IterationRecord.newEpochWatermark(
-                                            5, OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                                            5, OperatorUtils.getUniqueSenderId(operatorId, 0))),
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            Integer.MAX_VALUE,
+                                            OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
             // Checks the other lifecycles.
@@ -129,6 +147,8 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.PROCESS_ELEMENT,
                             LifeCycle.PROCESS_ELEMENT,
                             LifeCycle.PROCESS_ELEMENT,
+                            LifeCycle.EPOCH_WATERMARK_INCREMENTED,
+                            LifeCycle.ITERATION_TERMINATION,
                             LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
@@ -147,7 +167,9 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
 
     private static class LifeCycleTrackingTwoInputStreamOperator
             extends AbstractStreamOperatorV2<Integer>
-            implements MultipleInputStreamOperator<Integer>, BoundedMultiInput {
+            implements MultipleInputStreamOperator<Integer>,
+                    BoundedMultiInput,
+                    IterationListener<Integer> {
 
         private final int numberOfInputs;
 
@@ -226,6 +248,17 @@ public class MultipleInputAllRoundWrapperOperatorTest extends TestLogger {
         public void endInput(int inputId) throws Exception {
             LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.EPOCH_WATERMARK_INCREMENTED);
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.ITERATION_TERMINATION);
+        }
     }
 
     /** The operator factory for the lifecycle-tracking operator. */
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
index 9ebb975..fd75d65 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/OneInputAllRoundWrapperOperatorTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.allround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.OperatorWrapper;
@@ -48,6 +49,7 @@ import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Test;
@@ -83,6 +85,9 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
             harness.processElement(new StreamRecord<>(IterationRecord.newRecord(6, 2), 3));
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one")));
+            harness.processElement(
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")));
 
             // Checks the output
             assertEquals(
@@ -91,7 +96,11 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
                             new StreamRecord<>(IterationRecord.newRecord(6, 2), 3),
                             new StreamRecord<>(
                                     IterationRecord.newEpochWatermark(
-                                            5, OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                                            5, OperatorUtils.getUniqueSenderId(operatorId, 0))),
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            Integer.MAX_VALUE,
+                                            OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
             // Check the other lifecycles.
@@ -123,6 +132,8 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT,
                             LifeCycle.PROCESS_ELEMENT,
+                            LifeCycle.EPOCH_WATERMARK_INCREMENTED,
+                            LifeCycle.ITERATION_TERMINATION,
                             LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
@@ -201,7 +212,9 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
 
     private static class LifeCycleTrackingOneInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {
+            implements OneInputStreamOperator<Integer, Integer>,
+                    BoundedOneInput,
+                    IterationListener<Integer> {
 
         @Override
         public void setup(
@@ -270,5 +283,16 @@ public class OneInputAllRoundWrapperOperatorTest extends TestLogger {
         public void endInput() throws Exception {
             LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.EPOCH_WATERMARK_INCREMENTED);
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.ITERATION_TERMINATION);
+        }
     }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
index 82d5854..5930e8b 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/allround/TwoInputAllRoundWrapperOperatorTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.allround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
@@ -44,6 +45,7 @@ import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
 import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Test;
@@ -81,6 +83,14 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-0")), 0);
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(5, "only-one-1")), 1);
+            harness.processElement(
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-0")),
+                    0);
+            harness.processElement(
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-1")),
+                    1);
 
             // Checks the output
             assertEquals(
@@ -89,7 +99,11 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
                             new StreamRecord<>(IterationRecord.newRecord(6, 2), 3),
                             new StreamRecord<>(
                                     IterationRecord.newEpochWatermark(
-                                            5, OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                                            5, OperatorUtils.getUniqueSenderId(operatorId, 0))),
+                            new StreamRecord<>(
+                                    IterationRecord.newEpochWatermark(
+                                            Integer.MAX_VALUE,
+                                            OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
             // Checks the other lifecycles.
@@ -122,6 +136,8 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.OPEN,
                             LifeCycle.PROCESS_ELEMENT_1,
                             LifeCycle.PROCESS_ELEMENT_2,
+                            LifeCycle.EPOCH_WATERMARK_INCREMENTED,
+                            LifeCycle.ITERATION_TERMINATION,
                             LifeCycle.PREPARE_SNAPSHOT_PRE_BARRIER,
                             LifeCycle.SNAPSHOT_STATE,
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
@@ -138,7 +154,9 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
 
     private static class LifeCycleTrackingTwoInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements TwoInputStreamOperator<Integer, Integer, Integer>, BoundedMultiInput {
+            implements TwoInputStreamOperator<Integer, Integer, Integer>,
+                    BoundedMultiInput,
+                    IterationListener<Integer> {
 
         @Override
         public void setup(
@@ -213,5 +231,16 @@ public class TwoInputAllRoundWrapperOperatorTest extends TestLogger {
         public void endInput(int inputId) throws Exception {
             LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.EPOCH_WATERMARK_INCREMENTED);
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.ITERATION_TERMINATION);
+        }
     }
 }
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
index 9dba5fa..59685a9 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/MultipleInputPerRoundWrapperOperatorTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.perround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
@@ -47,6 +48,7 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.MultipleInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Test;
@@ -115,11 +117,17 @@ public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one-2")), 2);
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one-0")), 0);
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-0")),
+                    0);
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one-1")), 1);
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-1")),
+                    1);
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one-2")), 2);
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one-2")),
+                    2);
 
             // Checks the output
             assertEquals(
@@ -129,7 +137,8 @@ public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
                                             1, OperatorUtils.getUniqueSenderId(operatorId, 0))),
                             new StreamRecord<>(
                                     IterationRecord.newEpochWatermark(
-                                            2, OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                                            Integer.MAX_VALUE,
+                                            OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
             harness.processEvent(EndOfData.INSTANCE, 0);
@@ -157,6 +166,7 @@ public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.EPOCH_WATERMARK_INCREMENTED,
                             // The first input
                             LifeCycle.END_INPUT,
                             // The second input
@@ -165,6 +175,7 @@ public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE,
+                            LifeCycle.ITERATION_TERMINATION,
                             // The first input
                             LifeCycle.END_INPUT,
                             // The second input
@@ -179,7 +190,9 @@ public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
 
     private static class LifeCycleTrackingMultiInputStreamOperator
             extends AbstractStreamOperatorV2<Integer>
-            implements MultipleInputStreamOperator<Integer>, BoundedMultiInput {
+            implements MultipleInputStreamOperator<Integer>,
+                    BoundedMultiInput,
+                    IterationListener<Integer> {
 
         private final int numberOfInputs;
 
@@ -258,6 +271,17 @@ public class MultipleInputPerRoundWrapperOperatorTest extends TestLogger {
         public void endInput(int inputId) throws Exception {
             LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.EPOCH_WATERMARK_INCREMENTED);
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.ITERATION_TERMINATION);
+        }
     }
 
     /** Life-cycle tracking stream operator factory. */
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
index 214fb79..43460a7 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/OneInputPerRoundWrapperOperatorTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.OperatorWrapper;
@@ -54,6 +55,7 @@ import org.apache.flink.streaming.runtime.tasks.OneInputStreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.apache.commons.collections.IteratorUtils;
@@ -125,7 +127,8 @@ public class OneInputPerRoundWrapperOperatorTest extends TestLogger {
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one")));
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one")));
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")));
 
             // Checks the output
             assertEquals(
@@ -135,7 +138,8 @@ public class OneInputPerRoundWrapperOperatorTest extends TestLogger {
                                             1, OperatorUtils.getUniqueSenderId(operatorId, 0))),
                             new StreamRecord<>(
                                     IterationRecord.newEpochWatermark(
-                                            2, OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                                            Integer.MAX_VALUE,
+                                            OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
             harness.processEvent(EndOfData.INSTANCE, 0);
@@ -163,9 +167,11 @@ public class OneInputPerRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.EPOCH_WATERMARK_INCREMENTED,
                             LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE,
+                            LifeCycle.ITERATION_TERMINATION,
                             LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE),
@@ -264,7 +270,9 @@ public class OneInputPerRoundWrapperOperatorTest extends TestLogger {
 
     private static class LifeCycleTrackingOneInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements OneInputStreamOperator<Integer, Integer>, BoundedOneInput {
+            implements OneInputStreamOperator<Integer, Integer>,
+                    BoundedOneInput,
+                    IterationListener<Integer> {
 
         @Override
         public void setup(
@@ -333,6 +341,17 @@ public class OneInputPerRoundWrapperOperatorTest extends TestLogger {
         public void endInput() throws Exception {
             LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.EPOCH_WATERMARK_INCREMENTED);
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.ITERATION_TERMINATION);
+        }
     }
 
     private static class StatefulOperator extends AbstractStreamOperator<Integer>
diff --git a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
index f2134b8..6e08e57 100644
--- a/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
+++ b/flink-ml-iteration/src/test/java/org/apache/flink/iteration/operator/perround/TwoInputPerRoundWrapperOperatorTest.java
@@ -19,6 +19,7 @@
 package org.apache.flink.iteration.operator.perround;
 
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.iteration.IterationListener;
 import org.apache.flink.iteration.IterationRecord;
 import org.apache.flink.iteration.operator.OperatorUtils;
 import org.apache.flink.iteration.operator.WrapperOperatorFactory;
@@ -45,6 +46,7 @@ import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarness;
 import org.apache.flink.streaming.runtime.tasks.StreamTaskMailboxTestHarnessBuilder;
 import org.apache.flink.streaming.runtime.tasks.TwoInputStreamTask;
+import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 
 import org.junit.Test;
@@ -110,9 +112,13 @@ public class TwoInputPerRoundWrapperOperatorTest extends TestLogger {
             harness.processElement(
                     new StreamRecord<>(IterationRecord.newEpochWatermark(1, "only-one")), 1);
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one")), 0);
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")),
+                    0);
             harness.processElement(
-                    new StreamRecord<>(IterationRecord.newEpochWatermark(2, "only-one")), 1);
+                    new StreamRecord<>(
+                            IterationRecord.newEpochWatermark(Integer.MAX_VALUE, "only-one")),
+                    1);
 
             // Checks the output
             assertEquals(
@@ -122,7 +128,8 @@ public class TwoInputPerRoundWrapperOperatorTest extends TestLogger {
                                             1, OperatorUtils.getUniqueSenderId(operatorId, 0))),
                             new StreamRecord<>(
                                     IterationRecord.newEpochWatermark(
-                                            2, OperatorUtils.getUniqueSenderId(operatorId, 0)))),
+                                            Integer.MAX_VALUE,
+                                            OperatorUtils.getUniqueSenderId(operatorId, 0)))),
                     new ArrayList<>(harness.getOutput()));
 
             harness.processEvent(EndOfData.INSTANCE, 0);
@@ -151,12 +158,14 @@ public class TwoInputPerRoundWrapperOperatorTest extends TestLogger {
                             LifeCycle.NOTIFY_CHECKPOINT_COMPLETE,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
                             LifeCycle.NOTIFY_CHECKPOINT_ABORT,
+                            LifeCycle.EPOCH_WATERMARK_INCREMENTED,
                             // The first input
                             LifeCycle.END_INPUT,
                             // The second input
                             LifeCycle.END_INPUT,
                             LifeCycle.FINISH,
                             LifeCycle.CLOSE,
+                            LifeCycle.ITERATION_TERMINATION,
                             // The first input
                             LifeCycle.END_INPUT,
                             // The second input
@@ -169,7 +178,9 @@ public class TwoInputPerRoundWrapperOperatorTest extends TestLogger {
 
     private static class LifeCycleTrackingTwoInputStreamOperator
             extends AbstractStreamOperator<Integer>
-            implements TwoInputStreamOperator<Integer, Integer, Integer>, BoundedMultiInput {
+            implements TwoInputStreamOperator<Integer, Integer, Integer>,
+                    BoundedMultiInput,
+                    IterationListener<Integer> {
 
         @Override
         public void setup(
@@ -244,5 +255,16 @@ public class TwoInputPerRoundWrapperOperatorTest extends TestLogger {
         public void endInput(int inputId) throws Exception {
             LIFE_CYCLES.add(LifeCycle.END_INPUT);
         }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.EPOCH_WATERMARK_INCREMENTED);
+        }
+
+        @Override
+        public void onIterationTerminated(Context context, Collector<Integer> collector) {
+            LIFE_CYCLES.add(LifeCycle.ITERATION_TERMINATION);
+        }
     }
 }