You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by dw...@apache.org on 2021/12/08 07:18:15 UTC

[flink] branch master updated (8a25593 -> 2b167ae)

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git.


    from 8a25593  [hotfix][docs][hive] Minor update to "Hive Read & Write " page (#16529)
     new 72a2471  [FLINK-23532] Pass a flag for draining along with EndOfData
     new 5a0dcfb  [hotfix] Remove duplicated handling of stop-with-savepoint trigger failure
     new 0bd397b  [FLINK-23532] Unify stop-with-savepoint without and with drain
     new 2b167ae  [FLINK-23532] Remove unnecessary StreamTask#finishTask

The 4 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../flink/runtime/io/PullingAsyncDataInput.java    |  16 +-
 .../flink/runtime/io/network/api/EndOfData.java    |  44 +++--
 .../flink/runtime/io/network/api/StopMode.java     |  12 +-
 .../network/api/serialization/EventSerializer.java |  12 +-
 .../network/api/writer/ResultPartitionWriter.java  |   6 +-
 .../partition/BoundedBlockingResultPartition.java  |   5 +-
 .../partition/PipelinedResultPartition.java        |   5 +-
 .../io/network/partition/ResultPartition.java      |   3 +-
 .../partition/SortMergeResultPartition.java        |   5 +-
 .../io/network/partition/consumer/InputGate.java   |   2 -
 .../partition/consumer/SingleInputGate.java        |  13 +-
 .../network/partition/consumer/UnionInputGate.java |  14 +-
 ...bleNotifyingResultPartitionWriterDecorator.java |   5 +-
 .../runtime/taskmanager/InputGateWithMetrics.java  |   2 +-
 .../api/serialization/EventSerializerTest.java     |   4 +-
 ...cordOrEventCollectingResultPartitionWriter.java |   5 +-
 .../netty/PartitionRequestServerHandlerTest.java   |   3 +-
 .../BoundedBlockingSubpartitionWriteReadTest.java  |   3 +-
 .../partition/MockResultPartitionWriter.java       |   3 +-
 .../io/network/partition/ResultPartitionTest.java  |   5 +-
 .../partition/consumer/SingleInputGateTest.java    |  62 ++++++-
 .../partition/consumer/TestInputChannel.java       |   7 +-
 .../partition/consumer/UnionInputGateTest.java     |  54 +++++-
 .../streaming/api/operators/SourceOperator.java    |  29 +++-
 .../runtime/io/AbstractStreamTaskNetworkInput.java |  10 +-
 .../streaming/runtime/io/DataInputStatus.java      |   3 +
 .../runtime/io/MultipleInputSelectionHandler.java  |   7 +-
 .../io/checkpointing/CheckpointedInputGate.java    |   2 +-
 .../runtime/tasks/FinishedOperatorChain.java       |   4 +-
 .../runtime/tasks/MultipleInputStreamTask.java     |  23 +--
 .../streaming/runtime/tasks/OperatorChain.java     |   6 +-
 .../runtime/tasks/RegularOperatorChain.java        |   6 +-
 .../runtime/tasks/SourceOperatorStreamTask.java    |  34 ++--
 .../streaming/runtime/tasks/SourceStreamTask.java  |  88 +++++-----
 .../runtime/tasks/StreamOperatorWrapper.java       |  26 ++-
 .../flink/streaming/runtime/tasks/StreamTask.java  | 139 +++------------
 .../consumer/StreamTestSingleInputGate.java        |   4 +-
 .../api/operators/SourceOperatorTest.java          |   3 +-
 .../streaming/runtime/io/MockIndexedInputGate.java |   4 +-
 .../flink/streaming/runtime/io/MockInputGate.java  |   2 +-
 .../AlignedCheckpointsMassiveRandomTest.java       |   4 +-
 ...tStreamTaskChainedSourcesCheckpointingTest.java |  17 +-
 .../runtime/tasks/MultipleInputStreamTaskTest.java |  10 +-
 .../tasks/SourceOperatorStreamTaskTest.java        |  11 +-
 .../runtime/tasks/SourceStreamTaskTest.java        | 189 +--------------------
 .../runtime/tasks/SourceTaskTerminationTest.java   |   9 +-
 .../runtime/tasks/StreamOperatorWrapperTest.java   |   5 +-
 .../tasks/StreamTaskFinalCheckpointsTest.java      | 161 +++---------------
 .../streaming/runtime/tasks/StreamTaskTest.java    |  14 +-
 .../runtime/tasks/StreamTaskTestHarness.java       |   3 +-
 .../runtime/tasks/SynchronousCheckpointTest.java   |  19 ---
 .../runtime/tasks/TwoInputStreamTaskTest.java      |   4 +-
 .../JobMasterStopWithSavepointITCase.java          |  43 -----
 .../operators/lifecycle/BoundedSourceITCase.java   |   2 +-
 .../lifecycle/PartiallyFinishedSourcesITCase.java  |   2 +-
 .../lifecycle/StopWithSavepointITCase.java         |  11 +-
 .../operators/lifecycle/graph/TestEventSource.java |  13 --
 .../validation/TestJobDataFlowValidator.java       |  81 ++++++---
 .../flink/test/checkpointing/SavepointITCase.java  |  14 +-
 .../test/scheduling/AdaptiveSchedulerITCase.java   |  12 +-
 60 files changed, 552 insertions(+), 752 deletions(-)
 copy flink-table/flink-table-common/src/main/java/org/apache/flink/table/types/DataTypeQueryable.java => flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/StopMode.java (69%)

[flink] 04/04: [FLINK-23532] Remove unnecessary StreamTask#finishTask

Posted by dw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 2b167ae9764c02d4c77a41f2afd056ed8fd5f04f
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Tue Nov 30 14:47:36 2021 +0100

    [FLINK-23532] Remove unnecessary StreamTask#finishTask
---
 .../runtime/tasks/SourceOperatorStreamTask.java     |  5 -----
 .../streaming/runtime/tasks/SourceStreamTask.java   | 21 ++++-----------------
 .../flink/streaming/runtime/tasks/StreamTask.java   | 11 -----------
 .../jobmaster/JobMasterStopWithSavepointITCase.java | 13 -------------
 4 files changed, 4 insertions(+), 46 deletions(-)

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
index b6b8c6c..f0715bc 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
@@ -108,11 +108,6 @@ public class SourceOperatorStreamTask<T> extends StreamTask<T, SourceOperator<T,
     }
 
     @Override
-    protected void finishTask() throws Exception {
-        mailboxProcessor.allActionsCompleted();
-    }
-
-    @Override
     public CompletableFuture<Boolean> triggerCheckpointAsync(
             CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {
         if (!isExternallyInducedSource) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
index 1a94c19..e10934c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
@@ -202,32 +202,19 @@ public class SourceStreamTask<
     @Override
     protected void cancelTask() {
         if (stopped.compareAndSet(false, true)) {
-            cancelOperator(true);
+            cancelOperator();
         }
     }
 
-    @Override
-    protected void finishTask() {
-        this.finishingReason = FinishingReason.STOP_WITH_SAVEPOINT_NO_DRAIN;
-        /**
-         * Currently stop with savepoint relies on the EndOfPartitionEvents propagation and performs
-         * clean shutdown after the stop with savepoint (which can produce some records to process
-         * after the savepoint while stopping). If we interrupt source thread, we might leave the
-         * network stack in an inconsistent state. So, if we want to relay on the clean shutdown, we
-         * can not interrupt the source thread.
-         */
-        cancelOperator(false);
-    }
-
-    private void cancelOperator(boolean interruptThread) {
+    private void cancelOperator() {
         try {
             if (mainOperator != null) {
                 mainOperator.cancel();
             }
         } finally {
-            if (sourceThread.isAlive() && interruptThread) {
+            if (sourceThread.isAlive()) {
                 interruptSourceThread();
-            } else if (!sourceThread.isAlive() && !sourceThread.getCompletionFuture().isDone()) {
+            } else if (!sourceThread.getCompletionFuture().isDone()) {
                 // sourceThread not alive and completion future not done means source thread
                 // didn't start and we need to manually complete the future
                 sourceThread.getCompletionFuture().complete(null);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 558a60f..1731214 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -609,17 +609,6 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
      */
     protected void advanceToEndOfEventTime() throws Exception {}
 
-    /**
-     * Instructs the task to go through its normal termination routine, i.e. exit the run-loop and
-     * call {@link StreamOperator#finish()} and {@link StreamOperator#close()} on its operators.
-     *
-     * <p>This is used by the source task to get out of the run-loop when the job is stopped with a
-     * savepoint.
-     *
-     * <p>For tasks other than the source task, this method does nothing.
-     */
-    protected void finishTask() throws Exception {}
-
     // ------------------------------------------------------------------------
     //  Core work methods of the Stream Task
     // ------------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java
index 1ecdce8..c9a1cb9 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java
@@ -338,11 +338,6 @@ public class JobMasterStopWithSavepointITCase extends AbstractTestBase {
                 long checkpointId, long latestCompletedCheckpointId) {
             return CompletableFuture.completedFuture(null);
         }
-
-        @Override
-        protected void finishTask() {
-            mailboxProcessor.allActionsCompleted();
-        }
     }
 
     /** A {@link StreamTask} that simply waits to be terminated normally. */
@@ -364,14 +359,6 @@ public class JobMasterStopWithSavepointITCase extends AbstractTestBase {
                 mailboxProcessor.suspend();
             }
         }
-
-        @Override
-        public void finishTask() throws Exception {
-            finishingLatch.await();
-            if (suspension != null) {
-                suspension.resume();
-            }
-        }
     }
 
     /**

[flink] 01/04: [FLINK-23532] Pass a flag for draining along with EndOfData

Posted by dw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 72a2471fdb08a625cbe173ef89b53db8425a14b6
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Tue Nov 30 14:13:25 2021 +0100

    [FLINK-23532] Pass a flag for draining along with EndOfData
    
    For the sake of unification, we want to emit EndOfData in case of stop-with-savepoint both with and without drain. This is a preparation so that we can enclose the flag inside of EndOfData.
---
 .../flink/runtime/io/PullingAsyncDataInput.java    | 10 +++++
 .../flink/runtime/io/network/api/EndOfData.java    | 44 +++++++++++++++------
 .../network/api/serialization/EventSerializer.java | 11 +++++-
 .../network/api/writer/ResultPartitionWriter.java  |  5 ++-
 .../partition/BoundedBlockingResultPartition.java  |  4 +-
 .../partition/PipelinedResultPartition.java        |  4 +-
 .../io/network/partition/ResultPartition.java      |  2 +-
 .../partition/SortMergeResultPartition.java        |  4 +-
 .../io/network/partition/consumer/InputGate.java   | 10 +++++
 .../partition/consumer/SingleInputGate.java        |  7 ++++
 .../network/partition/consumer/UnionInputGate.java |  7 ++++
 ...bleNotifyingResultPartitionWriterDecorator.java |  4 +-
 .../runtime/taskmanager/InputGateWithMetrics.java  |  5 +++
 .../api/serialization/EventSerializerTest.java     |  3 +-
 ...cordOrEventCollectingResultPartitionWriter.java |  4 +-
 .../netty/PartitionRequestServerHandlerTest.java   |  2 +-
 .../BoundedBlockingSubpartitionWriteReadTest.java  |  2 +-
 .../partition/MockResultPartitionWriter.java       |  2 +-
 .../io/network/partition/ResultPartitionTest.java  |  4 +-
 .../partition/consumer/SingleInputGateTest.java    | 46 ++++++++++++++++++++++
 .../partition/consumer/TestInputChannel.java       |  6 ++-
 .../partition/consumer/UnionInputGateTest.java     | 43 ++++++++++++++++++++
 .../runtime/io/AbstractStreamTaskNetworkInput.java |  4 +-
 .../streaming/runtime/io/DataInputStatus.java      |  3 ++
 .../runtime/io/MultipleInputSelectionHandler.java  |  7 +++-
 .../io/checkpointing/CheckpointedInputGate.java    |  5 +++
 .../flink/streaming/runtime/tasks/StreamTask.java  |  4 +-
 .../consumer/StreamTestSingleInputGate.java        |  2 +-
 .../streaming/runtime/io/MockIndexedInputGate.java |  5 +++
 .../flink/streaming/runtime/io/MockInputGate.java  |  5 +++
 .../AlignedCheckpointsMassiveRandomTest.java       |  5 +++
 ...tStreamTaskChainedSourcesCheckpointingTest.java | 17 ++++----
 .../runtime/tasks/MultipleInputStreamTaskTest.java |  9 +++--
 .../tasks/SourceOperatorStreamTaskTest.java        | 11 +++---
 .../runtime/tasks/SourceStreamTaskTest.java        |  7 ++--
 .../runtime/tasks/SourceTaskTerminationTest.java   |  2 +-
 .../tasks/StreamTaskFinalCheckpointsTest.java      | 16 ++++----
 .../runtime/tasks/StreamTaskTestHarness.java       |  2 +-
 .../runtime/tasks/TwoInputStreamTaskTest.java      |  3 +-
 39 files changed, 269 insertions(+), 67 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
index fa87f13..73869cb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
@@ -18,6 +18,7 @@
 package org.apache.flink.runtime.io;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.io.network.api.EndOfData;
 
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
@@ -65,4 +66,13 @@ public interface PullingAsyncDataInput<T> extends AvailabilityProvider {
      *     point
      */
     boolean hasReceivedEndOfData();
+
+    /**
+     * Tells if we should drain all results in case we received {@link EndOfData} on all channels.
+     * If any of the upstream subtasks finished because of the stop-with-savepoint --no-drain, we
+     * should not drain the current task. See also {@code StopMode}.
+     *
+     * <p>We should check the {@link #hasReceivedEndOfData()} first.
+     */
+    boolean shouldDrainOnEndOfData();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
index 7421f90..4e244f9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
@@ -23,6 +23,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.event.RuntimeEvent;
 
 import java.io.IOException;
+import java.util.Objects;
 
 /**
  * This event indicates there will be no more data records in a subpartition. There still might be
@@ -35,36 +36,57 @@ import java.io.IOException;
  */
 public class EndOfData extends RuntimeEvent {
 
-    /** The singleton instance of this event. */
-    public static final EndOfData INSTANCE = new EndOfData();
+    private final boolean shouldDrain;
 
     // ------------------------------------------------------------------------
 
-    // not instantiable
-    private EndOfData() {}
+    public EndOfData(boolean shouldDrain) {
+        this.shouldDrain = shouldDrain;
+    }
+
+    public boolean shouldDrain() {
+        return shouldDrain;
+    }
 
     // ------------------------------------------------------------------------
 
+    //
+    //  These methods are inherited from the generic serialization of AbstractEvent
+    //  but would require EndOfData to be mutable. Since all serialization
+    //  for events goes through the EventSerializer class, which has special serialization
+    //  for EndOfData, we don't need these methods
+    //
     @Override
-    public void write(DataOutputView out) throws IOException {}
+    public void write(DataOutputView out) throws IOException {
+        throw new UnsupportedOperationException("This method should never be called");
+    }
 
     @Override
-    public void read(DataInputView in) throws IOException {}
+    public void read(DataInputView in) throws IOException {
+        throw new UnsupportedOperationException("This method should never be called");
+    }
 
     // ------------------------------------------------------------------------
 
     @Override
-    public int hashCode() {
-        return 1965146684;
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        EndOfData endOfData = (EndOfData) o;
+        return shouldDrain == endOfData.shouldDrain;
     }
 
     @Override
-    public boolean equals(Object obj) {
-        return obj != null && obj.getClass() == EndOfData.class;
+    public int hashCode() {
+        return Objects.hash(shouldDrain);
     }
 
     @Override
     public String toString() {
-        return getClass().getSimpleName();
+        return "EndOfData{shouldDrain=" + shouldDrain + '}';
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
index bca2feb..46ee08d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
@@ -94,7 +94,14 @@ public class EventSerializer {
         } else if (eventClass == EndOfChannelStateEvent.class) {
             return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_CHANNEL_STATE_EVENT});
         } else if (eventClass == EndOfData.class) {
-            return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_USER_RECORDS_EVENT});
+            return ByteBuffer.wrap(
+                    new byte[] {
+                        0,
+                        0,
+                        0,
+                        END_OF_USER_RECORDS_EVENT,
+                        ((EndOfData) event).shouldDrain() ? (byte) 1 : (byte) 0
+                    });
         } else if (eventClass == CancelCheckpointMarker.class) {
             CancelCheckpointMarker marker = (CancelCheckpointMarker) event;
 
@@ -157,7 +164,7 @@ public class EventSerializer {
             } else if (type == END_OF_CHANNEL_STATE_EVENT) {
                 return EndOfChannelStateEvent.INSTANCE;
             } else if (type == END_OF_USER_RECORDS_EVENT) {
-                return EndOfData.INSTANCE;
+                return new EndOfData(buffer.get() == 1);
             } else if (type == CANCEL_CHECKPOINT_MARKER_EVENT) {
                 long id = buffer.getLong();
                 return new CancelCheckpointMarker(id);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index bd9a2ef..ff0160a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -69,8 +69,11 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
     /**
      * Notifies the downstream tasks that this {@code ResultPartitionWriter} have emitted all the
      * user records.
+     *
+     * @param shouldDrain tells if we should flush all records or not (it is false in case of
+     *     stop-with-savepoint (--no-drain))
      */
-    void notifyEndOfData() throws IOException;
+    void notifyEndOfData(boolean shouldDrain) throws IOException;
 
     /**
      * Gets the future indicating whether all the records has been processed by the downstream
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
index a3f1372..9b6a724 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
@@ -66,9 +66,9 @@ public class BoundedBlockingResultPartition extends BufferWritingResultPartition
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         if (!hasNotifiedEndOfUserRecords) {
-            broadcastEvent(EndOfData.INSTANCE, false);
+            broadcastEvent(new EndOfData(shouldDrain), false);
             hasNotifiedEndOfUserRecords = true;
         }
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
index f94d9ad..dd00bb1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
@@ -193,10 +193,10 @@ public class PipelinedResultPartition extends BufferWritingResultPartition
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         synchronized (lock) {
             if (!hasNotifiedEndOfUserRecords) {
-                broadcastEvent(EndOfData.INSTANCE, false);
+                broadcastEvent(new EndOfData(shouldDrain), false);
                 hasNotifiedEndOfUserRecords = true;
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 1207fed..87561ec 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -200,7 +200,7 @@ public abstract class ResultPartition implements ResultPartitionWriter {
     // ------------------------------------------------------------------------
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         throw new UnsupportedOperationException();
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index bc214ce..a8e8d31 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -398,10 +398,10 @@ public class SortMergeResultPartition extends ResultPartition {
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         synchronized (lock) {
             if (!hasNotifiedEndOfUserRecords) {
-                broadcastEvent(EndOfData.INSTANCE, false);
+                broadcastEvent(new EndOfData(shouldDrain), false);
                 hasNotifiedEndOfUserRecords = true;
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
index 2253676..012aec6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.PullingAsyncDataInput;
+import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.partition.ChannelStateHolder;
 
 import java.io.IOException;
@@ -101,6 +102,15 @@ public abstract class InputGate
     public abstract boolean hasReceivedEndOfData();
 
     /**
+     * Tells if we should drain all results in case we received {@link EndOfData} on all channels.
+     * If any of the upstream subtasks finished because of the stop-with-savepoint --no-drain, we
+     * should not drain the current task. See also {@code StopMode}.
+     *
+     * <p>We should check the {@link #hasReceivedEndOfData()} first.
+     */
+    public abstract boolean shouldDrainOnEndOfData();
+
+    /**
      * Blocking call waiting for next {@link BufferOrEvent}.
      *
      * <p>Note: It should be guaranteed that the previous returned buffer has been recycled before
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index f560ec4..ee78ff6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -208,6 +208,7 @@ public class SingleInputGate extends IndexedInputGate {
 
     private final ThroughputCalculator throughputCalculator;
     private final BufferDebloater bufferDebloater;
+    private boolean shouldDrainOnEndOfData = true;
 
     public SingleInputGate(
             String owningTaskName,
@@ -671,6 +672,11 @@ public class SingleInputGate extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return shouldDrainOnEndOfData;
+    }
+
+    @Override
     public String toString() {
         return "SingleInputGate{"
                 + "owningTaskName='"
@@ -849,6 +855,7 @@ public class SingleInputGate extends IndexedInputGate {
                 channelsWithEndOfUserRecords.set(currentChannel.getChannelIndex());
                 hasReceivedEndOfData =
                         channelsWithEndOfUserRecords.cardinality() == numberOfInputChannels;
+                shouldDrainOnEndOfData &= ((EndOfData) event).shouldDrain();
             }
         }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index ce8b5e8..4154615 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -76,6 +76,7 @@ public class UnionInputGate extends InputGate {
 
     private final Set<IndexedInputGate> inputGatesWithRemainingUserData;
 
+    private boolean shouldDrainOnEndOfData = true;
     /**
      * Gates, which notified this input gate about available data. We are using it as a FIFO queue
      * of {@link InputGate}s to avoid starvation and provide some basic fairness.
@@ -185,6 +186,11 @@ public class UnionInputGate extends InputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return shouldDrainOnEndOfData;
+    }
+
+    @Override
     public Optional<BufferOrEvent> getNext() throws IOException, InterruptedException {
         return getNextBufferOrEvent(true);
     }
@@ -289,6 +295,7 @@ public class UnionInputGate extends InputGate {
                 && bufferOrEvent.getEvent().getClass() == EndOfData.class
                 && inputGate.hasReceivedEndOfData()) {
 
+            shouldDrainOnEndOfData &= inputGate.shouldDrainOnEndOfData();
             if (!inputGatesWithRemainingUserData.remove(inputGate)) {
                 throw new IllegalStateException(
                         "Couldn't find input gate in set of remaining input gates.");
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
index f498ada..a7d425c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
@@ -149,8 +149,8 @@ public class ConsumableNotifyingResultPartitionWriterDecorator {
         }
 
         @Override
-        public void notifyEndOfData() throws IOException {
-            partitionWriter.notifyEndOfData();
+        public void notifyEndOfData(boolean shouldDrain) throws IOException {
+            partitionWriter.notifyEndOfData(shouldDrain);
         }
 
         @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
index 2a09d5f..dae257d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
@@ -101,6 +101,11 @@ public class InputGateWithMetrics extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return inputGate.shouldDrainOnEndOfData();
+    }
+
+    @Override
     public void setup() throws IOException {
         inputGate.setup();
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
index 5445a02..6a2ef20 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
@@ -49,7 +49,8 @@ public class EventSerializerTest {
     private final AbstractEvent[] events = {
         EndOfPartitionEvent.INSTANCE,
         EndOfSuperstepEvent.INSTANCE,
-        EndOfData.INSTANCE,
+        new EndOfData(true),
+        new EndOfData(false),
         new CheckpointBarrier(
                 1678L,
                 4623784L,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
index b95febb..bd2b667 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
@@ -89,9 +89,9 @@ public class RecordOrEventCollectingResultPartitionWriter<T>
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         if (collectNetworkEvents) {
-            broadcastEvent(EndOfData.INSTANCE, false);
+            broadcastEvent(new EndOfData(shouldDrain), false);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
index b0fcf42..9e1973f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
@@ -128,7 +128,7 @@ public class PartitionRequestServerHandlerTest extends TestLogger {
         partitionRequestQueue.notifyReaderCreated(viewReader);
 
         // Write the message to acknowledge all records are processed to server
-        resultPartition.notifyEndOfData();
+        resultPartition.notifyEndOfData(true);
         CompletableFuture<Void> allRecordsProcessedFuture =
                 resultPartition.getAllDataProcessedFuture();
         assertFalse(allRecordsProcessedFuture.isDone());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
index ed9b181..855ff82 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
@@ -256,7 +256,7 @@ public class BoundedBlockingSubpartitionWriteReadTest {
 
     private void writeEndOfData(BoundedBlockingSubpartition subpartition) throws IOException {
         try (BufferConsumer eventBufferConsumer =
-                EventSerializer.toBufferConsumer(EndOfData.INSTANCE, false)) {
+                EventSerializer.toBufferConsumer(new EndOfData(true), false)) {
             // Retain the buffer so that it can be recycled by each channel of targetPartition
             subpartition.add(eventBufferConsumer.copy(), 0);
         }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
index dc85d1c..6db08bd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
@@ -61,7 +61,7 @@ public class MockResultPartitionWriter implements ResultPartitionWriter {
     public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {}
 
     @Override
-    public void notifyEndOfData() throws IOException {}
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {}
 
     @Override
     public CompletableFuture<Void> getAllDataProcessedFuture() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
index e522bd4..06ac632 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
@@ -625,7 +625,7 @@ public class ResultPartitionTest {
         BufferWritingResultPartition bufferWritingResultPartition =
                 createResultPartition(ResultPartitionType.PIPELINED_BOUNDED);
 
-        bufferWritingResultPartition.notifyEndOfData();
+        bufferWritingResultPartition.notifyEndOfData(true);
         CompletableFuture<Void> allRecordsProcessedFuture =
                 bufferWritingResultPartition.getAllDataProcessedFuture();
         assertFalse(allRecordsProcessedFuture.isDone());
@@ -634,7 +634,7 @@ public class ResultPartitionTest {
             Buffer nextBuffer = ((PipelinedSubpartition) resultSubpartition).pollBuffer().buffer();
             assertFalse(nextBuffer.isBuffer());
             assertEquals(
-                    EndOfData.INSTANCE,
+                    new EndOfData(true),
                     EventSerializer.fromBuffer(nextBuffer, getClass().getClassLoader()));
         }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index c543014..1ea926ac 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -254,6 +254,52 @@ public class SingleInputGateTest extends InputGateTestBase {
         }
     }
 
+    @Test
+    public void testDrainFlagComputation() throws Exception {
+        // Setup
+        final SingleInputGate inputGate1 = createInputGate();
+        final SingleInputGate inputGate2 = createInputGate();
+
+        final TestInputChannel[] inputChannels1 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate1, 0), new TestInputChannel(inputGate1, 1)
+                };
+        inputGate1.setInputChannels(inputChannels1);
+        final TestInputChannel[] inputChannels2 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate2, 0), new TestInputChannel(inputGate2, 1)
+                };
+        inputGate2.setInputChannels(inputChannels2);
+
+        // Test
+        inputChannels1[1].readEndOfData(true);
+        inputChannels1[0].readEndOfData(false);
+
+        inputChannels2[1].readEndOfData(true);
+        inputChannels2[0].readEndOfData(true);
+
+        inputGate1.notifyChannelNonEmpty(inputChannels1[0]);
+        inputGate1.notifyChannelNonEmpty(inputChannels1[1]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[0]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[1]);
+
+        verifyBufferOrEvent(inputGate1, false, 0, true);
+        // we have received EndOfData on a single channel only
+        assertFalse(inputGate1.hasReceivedEndOfData());
+        verifyBufferOrEvent(inputGate1, false, 1, true);
+        assertTrue(inputGate1.hasReceivedEndOfData());
+        // one of the channels said we should not drain
+        assertFalse(inputGate1.shouldDrainOnEndOfData());
+
+        verifyBufferOrEvent(inputGate2, false, 0, true);
+        // we have received EndOfData on a single channel only
+        assertFalse(inputGate2.hasReceivedEndOfData());
+        verifyBufferOrEvent(inputGate2, false, 1, true);
+        assertTrue(inputGate2.hasReceivedEndOfData());
+        // both channels said we should drain
+        assertTrue(inputGate2.shouldDrainOnEndOfData());
+    }
+
     /**
      * Tests that the compressed buffer will be decompressed after calling {@link
      * SingleInputGate#getNext()}.
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
index c89b470..3011d3c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
@@ -107,9 +107,13 @@ public class TestInputChannel extends InputChannel {
     }
 
     TestInputChannel readEndOfData() throws IOException {
+        return readEndOfData(true);
+    }
+
+    TestInputChannel readEndOfData(boolean shouldDrain) throws IOException {
         addBufferAndAvailability(
                 new BufferAndAvailability(
-                        EventSerializer.toBuffer(EndOfData.INSTANCE, false),
+                        EventSerializer.toBuffer(new EndOfData(shouldDrain), false),
                         Buffer.DataType.EVENT_BUFFER,
                         0,
                         sequenceNumber++));
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index b5bd5e6..c9ca40e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -133,6 +133,49 @@ public class UnionInputGateTest extends InputGateTestBase {
     }
 
     @Test
+    public void testDrainFlagComputation() throws Exception {
+        // Setup
+        final SingleInputGate inputGate1 = createInputGate();
+        final SingleInputGate inputGate2 = createInputGate();
+
+        final TestInputChannel[] inputChannels1 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate1, 0), new TestInputChannel(inputGate1, 1)
+                };
+        inputGate1.setInputChannels(inputChannels1);
+        final TestInputChannel[] inputChannels2 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate2, 0), new TestInputChannel(inputGate2, 1)
+                };
+        inputGate2.setInputChannels(inputChannels2);
+
+        // Test
+        inputChannels1[1].readEndOfData(true);
+        inputChannels1[0].readEndOfData(false);
+
+        inputChannels2[1].readEndOfData(true);
+        inputChannels2[0].readEndOfData(true);
+
+        final UnionInputGate unionInputGate = new UnionInputGate(inputGate1, inputGate2);
+
+        inputGate1.notifyChannelNonEmpty(inputChannels1[0]);
+        inputGate1.notifyChannelNonEmpty(inputChannels1[1]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[0]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[1]);
+
+        verifyBufferOrEvent(unionInputGate, false, 0, true);
+        verifyBufferOrEvent(unionInputGate, false, 2, true);
+        // we have received EndOfData on a single input only
+        assertFalse(unionInputGate.hasReceivedEndOfData());
+
+        verifyBufferOrEvent(unionInputGate, false, 1, true);
+        verifyBufferOrEvent(unionInputGate, false, 3, true);
+        // both channels received EndOfData, one channel said we should not drain
+        assertTrue(unionInputGate.hasReceivedEndOfData());
+        assertFalse(unionInputGate.shouldDrainOnEndOfData());
+    }
+
+    @Test
     public void testIsAvailable() throws Exception {
         final SingleInputGate inputGate1 = createInputGate(1);
         TestInputChannel inputChannel1 = new TestInputChannel(inputGate1, 0);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
index 8e84865..d60c824 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
@@ -152,7 +152,9 @@ public abstract class AbstractStreamTaskNetworkInput<
         final AbstractEvent event = bufferOrEvent.getEvent();
         if (event.getClass() == EndOfData.class) {
             if (checkpointedInputGate.hasReceivedEndOfData()) {
-                return DataInputStatus.END_OF_DATA;
+                return checkpointedInputGate.shouldDrainOnEndOfData()
+                        ? DataInputStatus.END_OF_DATA
+                        : DataInputStatus.STOPPED;
             }
         } else if (event.getClass() == EndOfPartitionEvent.class) {
             // release the record deserializer immediately,
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java
index 8fe97ff..d1e45fa 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java
@@ -44,6 +44,9 @@ public enum DataInputStatus {
     /** Indicator that all persisted data of the data exchange has been successfully restored. */
     END_OF_RECOVERY,
 
+    /** Indicator that the input was stopped because of stop-with-savepoint without drain. */
+    STOPPED,
+
     /** Indicator that the input has reached the end of data. */
     END_OF_DATA,
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java
index 31be2b7..3405d7a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java
@@ -49,6 +49,8 @@ public class MultipleInputSelectionHandler {
 
     private long dataFinishedButNotPartition;
 
+    private boolean drainOnEndOfData = true;
+
     private enum OperatingMode {
         NO_INPUT_SELECTABLE,
         INPUT_SELECTABLE_PRESENT_NO_DATA_INPUTS_FINISHED,
@@ -91,6 +93,9 @@ public class MultipleInputSelectionHandler {
             case NOTHING_AVAILABLE:
                 availableInputsMask = unsetBitMask(availableInputsMask, inputIndex);
                 break;
+            case STOPPED:
+                this.drainOnEndOfData = false;
+                // fall through
             case END_OF_DATA:
                 dataFinishedButNotPartition = setBitMask(dataFinishedButNotPartition, inputIndex);
                 updateModeOnEndOfData();
@@ -126,7 +131,7 @@ public class MultipleInputSelectionHandler {
 
         if (updatedStatus == DataInputStatus.END_OF_DATA
                 && this.operatingMode == OperatingMode.ALL_DATA_INPUTS_FINISHED) {
-            return DataInputStatus.END_OF_DATA;
+            return drainOnEndOfData ? DataInputStatus.END_OF_DATA : DataInputStatus.STOPPED;
         }
 
         if (isAnyInputAvailable()) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
index 048d446..b497e4d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
@@ -227,6 +227,11 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven
         return inputGate.hasReceivedEndOfData();
     }
 
+    @Override
+    public boolean shouldDrainOnEndOfData() {
+        return inputGate.shouldDrainOnEndOfData();
+    }
+
     /**
      * Cleans up all internally held resources.
      *
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 2003c7b..06c6874 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -526,6 +526,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                 break;
             case END_OF_RECOVERY:
                 throw new IllegalStateException("We should not receive this event here.");
+            case STOPPED:
+                throw new UnsupportedOperationException("Not supported yet");
             case END_OF_DATA:
                 endData();
                 return;
@@ -575,7 +577,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
         this.finishedOperators = true;
 
         for (ResultPartitionWriter partitionWriter : getEnvironment().getAllWriters()) {
-            partitionWriter.notifyEndOfData();
+            partitionWriter.notifyEndOfData(true);
         }
 
         this.endOfDataReceived = true;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
index cc29517..32b6e96 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
@@ -124,7 +124,7 @@ public class StreamTestSingleInputGate<T> {
                         } else if (input != null && input.isDataEnd()) {
                             return Optional.of(
                                     new BufferAndAvailability(
-                                            EventSerializer.toBuffer(EndOfData.INSTANCE, false),
+                                            EventSerializer.toBuffer(new EndOfData(true), false),
                                             nextType,
                                             0,
                                             0));
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
index d7e7ae1..8bbe8fa 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
@@ -101,6 +101,11 @@ public class MockIndexedInputGate extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return false;
+    }
+
+    @Override
     public Optional<BufferOrEvent> getNext() {
         throw new UnsupportedOperationException();
     }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
index 43c4b97..993361e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
@@ -119,6 +119,11 @@ public class MockInputGate extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        throw new UnsupportedOperationException("Not implemented yet");
+    }
+
+    @Override
     public Optional<BufferOrEvent> getNext() {
         BufferOrEvent next = bufferOrEvents.poll();
         if (!finishAfterLastBuffer && bufferOrEvents.isEmpty()) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
index f35bd70..903ec71 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
@@ -173,6 +173,11 @@ public class AlignedCheckpointsMassiveRandomTest {
         }
 
         @Override
+        public boolean shouldDrainOnEndOfData() {
+            return false;
+        }
+
+        @Override
         public InputChannel getChannel(int channelIndex) {
             throw new UnsupportedOperationException();
         }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
index 9ceeae7..d433b69 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
@@ -281,10 +281,10 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
             CheckpointBarrier barrier = createStopWithSavepointDrainBarrier();
 
             testHarness.processElement(new StreamRecord<>("44", TimestampAssigner.NO_TIMESTAMP), 0);
-            testHarness.processEvent(EndOfData.INSTANCE, 0);
+            testHarness.processEvent(new EndOfData(true), 0);
             testHarness.processEvent(barrier, 0);
             testHarness.processElement(new StreamRecord<>(47d, TimestampAssigner.NO_TIMESTAMP), 1);
-            testHarness.processEvent(EndOfData.INSTANCE, 1);
+            testHarness.processEvent(new EndOfData(true), 1);
             testHarness.processEvent(barrier, 1);
 
             addSourceRecords(testHarness, 1, Boundedness.CONTINUOUS_UNBOUNDED, 1, 2);
@@ -311,7 +311,7 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                     containsInAnyOrder(expectedOutput.toArray()));
             assertThat(
                     actualOutput.subList(actualOutput.size() - 3, actualOutput.size()),
-                    contains(new StreamRecord<>("FINISH"), EndOfData.INSTANCE, barrier));
+                    contains(new StreamRecord<>("FINISH"), new EndOfData(true), barrier));
         }
     }
 
@@ -435,8 +435,8 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                 testHarness.processAll();
 
                 // The checkpoint 2 would be aligned after received all the EndOfPartitionEvent.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
-                testHarness.processEvent(EndOfData.INSTANCE, 1, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
+                testHarness.processEvent(new EndOfData(true), 1, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
                 testHarness.getTaskStateManager().getWaitForReportLatch().await();
@@ -493,8 +493,9 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData() throws IOException {
-                                        broadcastEvent(EndOfData.INSTANCE, false);
+                                    public void notifyEndOfData(boolean shouldDrain)
+                                            throws IOException {
+                                        broadcastEvent(new EndOfData(shouldDrain), false);
                                     }
                                 })
                         .addSourceInput(
@@ -522,7 +523,7 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
             testHarness.processElement(Watermark.MAX_WATERMARK);
             assertThat(output, is(empty()));
             testHarness.waitForTaskCompletion();
-            assertThat(output, contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
 
             for (StreamOperatorWrapper<?, ?> wrapper :
                     testHarness.getStreamTask().operatorChain.getAllOperators()) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
index f8eb733..c31dd3d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
@@ -989,15 +989,15 @@ public class MultipleInputStreamTaskTest {
                 assertEquals(2, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4, checkpointOptions);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
                 assertEquals(4, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after all the inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 1, 0);
-                testHarness.processEvent(EndOfData.INSTANCE, 2, 0);
+                testHarness.processEvent(new EndOfData(true), 1, 0);
+                testHarness.processEvent(new EndOfData(true), 2, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 2, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 6, checkpointOptions);
@@ -1058,7 +1058,8 @@ public class MultipleInputStreamTaskTest {
             testHarness.processElement(Watermark.MAX_WATERMARK, 2);
             testHarness.waitForTaskCompletion();
             assertThat(
-                    testHarness.getOutput(), contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+                    testHarness.getOutput(),
+                    contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
         }
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
index ffe4c15..d4fa9be 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
@@ -129,7 +129,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             Queue<Object> expectedOutput = new LinkedList<>();
             expectedOutput.add(Watermark.MAX_WATERMARK);
-            expectedOutput.add(EndOfData.INSTANCE);
+            expectedOutput.add(new EndOfData(true));
             expectedOutput.add(
                     new CheckpointBarrier(checkpointId, checkpointId, checkpointOptions));
 
@@ -145,7 +145,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             Queue<Object> expectedOutput = new LinkedList<>();
             expectedOutput.add(Watermark.MAX_WATERMARK);
-            expectedOutput.add(EndOfData.INSTANCE);
+            expectedOutput.add(new EndOfData(true));
             assertThat(testHarness.getOutput().toArray(), equalTo(expectedOutput.toArray()));
         }
     }
@@ -203,8 +203,9 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData() throws IOException {
-                                        broadcastEvent(EndOfData.INSTANCE, false);
+                                    public void notifyEndOfData(boolean shouldDrain)
+                                            throws IOException {
+                                        broadcastEvent(new EndOfData(shouldDrain), false);
                                     }
                                 })
                         .setupOperatorChain(sourceOperatorFactory)
@@ -214,7 +215,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             testHarness.getStreamTask().invoke();
             testHarness.processAll();
-            assertThat(output, contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
 
             LifeCycleMonitorSourceReader sourceReader =
                     (LifeCycleMonitorSourceReader)
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
index 5fbc65c..6553279 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
@@ -720,8 +720,9 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData() throws IOException {
-                                        broadcastEvent(EndOfData.INSTANCE, false);
+                                    public void notifyEndOfData(boolean shouldDrain)
+                                            throws IOException {
+                                        broadcastEvent(new EndOfData(shouldDrain), false);
                                     }
                                 })
                         .setupOperatorChain(new StreamSource<>(testSource))
@@ -732,7 +733,7 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
             harness.processAll();
             harness.streamTask.getCompletionFuture().get();
 
-            assertThat(output, contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
 
             LifeCycleMonitorSource source =
                     (LifeCycleMonitorSource)
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
index 8cf461c..5a67e86 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
@@ -107,7 +107,7 @@ public class SourceTaskTerminationTest extends TestLogger {
                 // if we are in TERMINATE mode, we expect the source task
                 // to emit MAX_WM before the SYNC_SAVEPOINT barrier.
                 verifyWatermark(srcTaskTestHarness.getOutput(), Watermark.MAX_WATERMARK);
-                verifyEvent(srcTaskTestHarness.getOutput(), EndOfData.INSTANCE);
+                verifyEvent(srcTaskTestHarness.getOutput(), new EndOfData(true));
             }
 
             verifyCheckpointBarrier(srcTaskTestHarness.getOutput(), syncSavepointId);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
index efd0ea1..88087d5 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
@@ -155,7 +155,7 @@ public class StreamTaskFinalCheckpointsTest {
                 assertEquals(2, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
@@ -163,8 +163,8 @@ public class StreamTaskFinalCheckpointsTest {
 
                 // Tests triggering checkpoint after received all the inputs have received
                 // EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 1);
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 2);
+                testHarness.processEvent(new EndOfData(true), 0, 1);
+                testHarness.processEvent(new EndOfData(true), 0, 2);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
                 checkpointFuture = triggerCheckpoint(testHarness, lastCheckpointId);
@@ -664,7 +664,7 @@ public class StreamTaskFinalCheckpointsTest {
                 assertArrayEquals(new int[] {0, 0, 0}, resumedCount);
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4, checkpointOptions);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
@@ -673,8 +673,8 @@ public class StreamTaskFinalCheckpointsTest {
 
                 // Tests triggering checkpoint after received all the inputs have received
                 // EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 1);
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 2);
+                testHarness.processEvent(new EndOfData(true), 0, 1);
+                testHarness.processEvent(new EndOfData(true), 0, 2);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
                 checkpointFuture = triggerCheckpoint(testHarness, 6, checkpointOptions);
@@ -759,7 +759,7 @@ public class StreamTaskFinalCheckpointsTest {
                 // The checkpoint is added to the mailbox and will be processed in the
                 // mailbox loop after call operators' finish method in the afterInvoke()
                 // method.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4);
                 checkpointFuture.thenAccept(
                         (ignored) -> {
@@ -947,7 +947,7 @@ public class StreamTaskFinalCheckpointsTest {
                                     checkpointMetaData.getTimestamp(),
                                     checkpointOptions),
                             Watermark.MAX_WATERMARK,
-                            EndOfData.INSTANCE));
+                            new EndOfData(true)));
         }
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index 2bae6ed..a95b17c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -492,7 +492,7 @@ public class StreamTaskTestHarness<OUT> {
 
     public void endInput(int gateIndex, int channelIndex, boolean emitEndOfData) {
         if (emitEndOfData) {
-            inputGates[gateIndex].sendEvent(EndOfData.INSTANCE, channelIndex);
+            inputGates[gateIndex].sendEvent(new EndOfData(true), channelIndex);
         }
         inputGates[gateIndex].sendEvent(EndOfPartitionEvent.INSTANCE, channelIndex);
     }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index 2686ad1..ab8304f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -542,7 +542,8 @@ public class TwoInputStreamTaskTest {
             testHarness.processElement(Watermark.MAX_WATERMARK, 1);
             testHarness.waitForTaskCompletion();
             assertThat(
-                    testHarness.getOutput(), contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+                    testHarness.getOutput(),
+                    contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
         }
     }
 

[flink] 02/04: [hotfix] Remove duplicated handling of stop-with-savepoint trigger failure

Posted by dw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 5a0dcfb85c3d0be22e6154d63a24d35d7d69855e
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Thu Dec 2 12:57:01 2021 +0100

    [hotfix] Remove duplicated handling of stop-with-savepoint trigger failure
    
    After FLINK-23741 the failure is handled in Task#triggerCheckpointBarrier, therefore we no longer need to handle it inside of StreamTask. Moreover, if a failure had occurred in assertTriggeringCheckpointExceptions, the failed checkpoint would not have been declined.
---
 .../runtime/tasks/MultipleInputStreamTask.java     |  8 ++------
 .../runtime/tasks/SourceOperatorStreamTask.java    |  8 ++------
 .../streaming/runtime/tasks/SourceStreamTask.java  | 14 ++++++-------
 .../flink/streaming/runtime/tasks/StreamTask.java  | 24 ----------------------
 4 files changed, 10 insertions(+), 44 deletions(-)

diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
index 89a5434..4f79660 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
@@ -251,12 +251,8 @@ public class MultipleInputStreamTask<OUT>
                 },
                 "stop chained Flip-27 source for stop-with-savepoint --drain");
 
-        return assertTriggeringCheckpointExceptions(
-                sourcesStopped.thenCompose(
-                        ignore ->
-                                triggerSourcesCheckpointAsync(
-                                        checkpointMetaData, checkpointOptions)),
-                checkpointMetaData.getCheckpointId());
+        return sourcesStopped.thenCompose(
+                ignore -> triggerSourcesCheckpointAsync(checkpointMetaData, checkpointOptions));
     }
 
     private void checkPendingCheckpointCompletedFuturesSize() {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
index 8d0bf97..6ce7b26 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
@@ -137,12 +137,8 @@ public class SourceOperatorStreamTask<T> extends StreamTask<T, SourceOperator<T,
                 },
                 "stop Flip-27 source for stop-with-savepoint --drain");
 
-        return assertTriggeringCheckpointExceptions(
-                operatorFinished.thenCompose(
-                        (ignore) ->
-                                super.triggerCheckpointAsync(
-                                        checkpointMetaData, checkpointOptions)),
-                checkpointMetaData.getCheckpointId());
+        return operatorFinished.thenCompose(
+                (ignore) -> super.triggerCheckpointAsync(checkpointMetaData, checkpointOptions));
     }
 
     @Override
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
index d51080e..0206ef8 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
@@ -286,14 +286,12 @@ public class SourceStreamTask<
                         stopOperatorForStopWithSavepointWithDrain(
                                 checkpointMetaData.getCheckpointId()),
                 "stop legacy source for stop-with-savepoint --drain");
-        return assertTriggeringCheckpointExceptions(
-                sourceThread
-                        .getCompletionFuture()
-                        .thenCompose(
-                                ignore ->
-                                        super.triggerCheckpointAsync(
-                                                checkpointMetaData, checkpointOptions)),
-                checkpointMetaData.getCheckpointId());
+        return sourceThread
+                .getCompletionFuture()
+                .thenCompose(
+                        ignore ->
+                                super.triggerCheckpointAsync(
+                                        checkpointMetaData, checkpointOptions));
     }
 
     private void stopOperatorForStopWithSavepointWithDrain(long checkpointId) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 06c6874..42a1c2b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -102,7 +102,6 @@ import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.SerializedValue;
 import org.apache.flink.util.TernaryBoolean;
-import org.apache.flink.util.WrappingRuntimeException;
 import org.apache.flink.util.concurrent.ExecutorThreadFactory;
 import org.apache.flink.util.concurrent.FutureUtils;
 import org.apache.flink.util.function.RunnableWithException;
@@ -1211,29 +1210,6 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
         return true;
     }
 
-    protected final CompletableFuture<Boolean> assertTriggeringCheckpointExceptions(
-            CompletableFuture<Boolean> triggerFuture, long checkpointId) {
-        CompletableFuture<Boolean> checkpointTriggered =
-                triggerFuture.exceptionally(
-                        error -> {
-                            if (ExceptionUtils.findThrowable(
-                                            error, RejectedExecutionException.class)
-                                    .isPresent()) {
-                                // This may happen if the mailbox is closed. It means that
-                                // the task is shutting down, so we just ignore it.
-                                LOG.debug(
-                                        "Triggering checkpoint {} for {} was rejected by the mailbox",
-                                        checkpointId,
-                                        getTaskNameWithSubtaskAndId());
-                                return false;
-                            } else {
-                                throw new WrappingRuntimeException(error);
-                            }
-                        });
-        FutureUtils.assertNoException(checkpointTriggered);
-        return checkpointTriggered;
-    }
-
     /**
      * Acquires the optional {@link CheckpointBarrierHandler} associated with this stream task. The
      * {@code CheckpointBarrierHandler} should exist if the task has data inputs and requires to

[flink] 03/04: [FLINK-23532] Unify stop-with-savepoint without and with drain

Posted by dw...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 0bd397bf660bfcc0a554636a0a1e7c46d72859f5
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Tue Nov 30 14:46:26 2021 +0100

    [FLINK-23532] Unify stop-with-savepoint without and with drain
    
    We unify the stop-with-savepoint with and without drain. Both of those
    options follow a similar codepath.
    
    1. First of all we always emit EndOfData in case of a clean shutdown, but with a flag which determines if we should call StreamOperator#finish and flush all records.
    2. Secondly, we stop sources once we receive a trigger for synchronous checkpoint instead of shutting them down after the checkpoint completes.
    3. Lastly, we no longer wait in a synchronous mailbox loop, but instead we wait on a future in the StreamTask#afterInvoke.
---
 .../flink/runtime/io/PullingAsyncDataInput.java    |  24 +--
 .../flink/runtime/io/network/api/EndOfData.java    |  16 +-
 .../flink/runtime/io/network/api/StopMode.java     |  32 ++++
 .../network/api/serialization/EventSerializer.java |   5 +-
 .../network/api/writer/ResultPartitionWriter.java  |   5 +-
 .../partition/BoundedBlockingResultPartition.java  |   5 +-
 .../partition/PipelinedResultPartition.java        |   5 +-
 .../io/network/partition/ResultPartition.java      |   3 +-
 .../partition/SortMergeResultPartition.java        |   5 +-
 .../io/network/partition/consumer/InputGate.java   |  12 --
 .../partition/consumer/SingleInputGate.java        |  18 +-
 .../network/partition/consumer/UnionInputGate.java |  19 ++-
 ...bleNotifyingResultPartitionWriterDecorator.java |   5 +-
 .../runtime/taskmanager/InputGateWithMetrics.java  |   7 +-
 .../api/serialization/EventSerializerTest.java     |   5 +-
 ...cordOrEventCollectingResultPartitionWriter.java |   5 +-
 .../netty/PartitionRequestServerHandlerTest.java   |   3 +-
 .../BoundedBlockingSubpartitionWriteReadTest.java  |   3 +-
 .../partition/MockResultPartitionWriter.java       |   3 +-
 .../io/network/partition/ResultPartitionTest.java  |   5 +-
 .../partition/consumer/SingleInputGateTest.java    |  36 ++--
 .../partition/consumer/TestInputChannel.java       |   7 +-
 .../partition/consumer/UnionInputGateTest.java     |  25 ++-
 .../streaming/api/operators/SourceOperator.java    |  29 +++-
 .../runtime/io/AbstractStreamTaskNetworkInput.java |  12 +-
 .../io/checkpointing/CheckpointedInputGate.java    |   7 +-
 .../runtime/tasks/FinishedOperatorChain.java       |   4 +-
 .../runtime/tasks/MultipleInputStreamTask.java     |  15 +-
 .../streaming/runtime/tasks/OperatorChain.java     |   6 +-
 .../runtime/tasks/RegularOperatorChain.java        |   6 +-
 .../runtime/tasks/SourceOperatorStreamTask.java    |  19 ++-
 .../streaming/runtime/tasks/SourceStreamTask.java  |  53 +++---
 .../runtime/tasks/StreamOperatorWrapper.java       |  26 ++-
 .../flink/streaming/runtime/tasks/StreamTask.java  | 104 +++--------
 .../consumer/StreamTestSingleInputGate.java        |   4 +-
 .../api/operators/SourceOperatorTest.java          |   3 +-
 .../streaming/runtime/io/MockIndexedInputGate.java |   9 +-
 .../flink/streaming/runtime/io/MockInputGate.java  |   7 +-
 .../AlignedCheckpointsMassiveRandomTest.java       |   9 +-
 ...tStreamTaskChainedSourcesCheckpointingTest.java |  18 +-
 .../runtime/tasks/MultipleInputStreamTaskTest.java |   9 +-
 .../tasks/SourceOperatorStreamTaskTest.java        |  12 +-
 .../runtime/tasks/SourceStreamTaskTest.java        | 190 +--------------------
 .../runtime/tasks/SourceTaskTerminationTest.java   |   9 +-
 .../runtime/tasks/StreamOperatorWrapperTest.java   |   5 +-
 .../tasks/StreamTaskFinalCheckpointsTest.java      | 161 +++--------------
 .../streaming/runtime/tasks/StreamTaskTest.java    |  14 +-
 .../runtime/tasks/StreamTaskTestHarness.java       |   3 +-
 .../runtime/tasks/SynchronousCheckpointTest.java   |  19 ---
 .../runtime/tasks/TwoInputStreamTaskTest.java      |   3 +-
 .../JobMasterStopWithSavepointITCase.java          |  30 ----
 .../operators/lifecycle/BoundedSourceITCase.java   |   2 +-
 .../lifecycle/PartiallyFinishedSourcesITCase.java  |   2 +-
 .../lifecycle/StopWithSavepointITCase.java         |  11 +-
 .../operators/lifecycle/graph/TestEventSource.java |  13 --
 .../validation/TestJobDataFlowValidator.java       |  81 ++++++---
 .../flink/test/checkpointing/SavepointITCase.java  |  14 +-
 .../test/scheduling/AdaptiveSchedulerITCase.java   |  12 +-
 58 files changed, 439 insertions(+), 735 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
index 73869cb..188a969 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
@@ -18,7 +18,6 @@
 package org.apache.flink.runtime.io;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.runtime.io.network.api.EndOfData;
 
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
@@ -62,17 +61,18 @@ public interface PullingAsyncDataInput<T> extends AvailabilityProvider {
     boolean isFinished();
 
     /**
-     * @return true if the input has seen all data. It might see only control events after that
-     *     point
-     */
-    boolean hasReceivedEndOfData();
-
-    /**
-     * Tells if we should drain all results in case we received {@link EndOfData} on all channels.
-     * If any of the upstream subtasks finished because of the stop-with-savepoint --no-drain, we
-     * should not drain the current task. See also {@code StopMode}.
+     * Tells if we consumed all available data.
      *
-     * <p>We should check the {@link #hasReceivedEndOfData()} first.
+     * <p>Moreover it tells us the reason why there is no more data incoming. If any of the upstream
+     * subtasks finished because of the stop-with-savepoint --no-drain, we should not drain the
+     * input. See also {@code StopMode}.
      */
-    boolean shouldDrainOnEndOfData();
+    EndOfDataStatus hasReceivedEndOfData();
+
+    /** Status for describing if we have reached the end of data. */
+    enum EndOfDataStatus {
+        NOT_END_OF_DATA,
+        DRAINED,
+        STOPPED
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
index 4e244f9..96c2b5e 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
@@ -36,16 +36,16 @@ import java.util.Objects;
  */
 public class EndOfData extends RuntimeEvent {
 
-    private final boolean shouldDrain;
+    private final StopMode mode;
 
     // ------------------------------------------------------------------------
 
-    public EndOfData(boolean shouldDrain) {
-        this.shouldDrain = shouldDrain;
+    public EndOfData(StopMode mode) {
+        this.mode = mode;
     }
 
-    public boolean shouldDrain() {
-        return shouldDrain;
+    public StopMode getStopMode() {
+        return mode;
     }
 
     // ------------------------------------------------------------------------
@@ -77,16 +77,16 @@ public class EndOfData extends RuntimeEvent {
             return false;
         }
         EndOfData endOfData = (EndOfData) o;
-        return shouldDrain == endOfData.shouldDrain;
+        return mode == endOfData.mode;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(shouldDrain);
+        return Objects.hash(mode);
     }
 
     @Override
     public String toString() {
-        return "EndOfData{shouldDrain=" + shouldDrain + '}';
+        return "EndOfData{mode=" + mode + '}';
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/StopMode.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/StopMode.java
new file mode 100644
index 0000000..4630849
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/StopMode.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.api;
+
+import org.apache.flink.annotation.Internal;
+
+/**
+ * Tells if the job is stopping because of consuming all data. {@link #DRAIN} means the job is
+ * stopping either with a stop-with-savepoint --drain or due to consuming all records in the source.
+ * A drained pipeline should call {@code StreamOperator#finish()} on all operators.
+ */
+@Internal
+public enum StopMode {
+    DRAIN,
+    NO_DRAIN
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
index 46ee08d..ba943a1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent;
 import org.apache.flink.runtime.io.network.api.EventAnnouncement;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -100,7 +101,7 @@ public class EventSerializer {
                         0,
                         0,
                         END_OF_USER_RECORDS_EVENT,
-                        ((EndOfData) event).shouldDrain() ? (byte) 1 : (byte) 0
+                        (byte) ((EndOfData) event).getStopMode().ordinal()
                     });
         } else if (eventClass == CancelCheckpointMarker.class) {
             CancelCheckpointMarker marker = (CancelCheckpointMarker) event;
@@ -164,7 +165,7 @@ public class EventSerializer {
             } else if (type == END_OF_CHANNEL_STATE_EVENT) {
                 return EndOfChannelStateEvent.INSTANCE;
             } else if (type == END_OF_USER_RECORDS_EVENT) {
-                return new EndOfData(buffer.get() == 1);
+                return new EndOfData(StopMode.values()[buffer.get()]);
             } else if (type == CANCEL_CHECKPOINT_MARKER_EVENT) {
                 long id = buffer.getLong();
                 return new CancelCheckpointMarker(id);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index ff0160a..081565f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.io.network.api.writer;
 
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.AvailabilityProvider;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
 import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
@@ -70,10 +71,10 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
      * Notifies the downstream tasks that this {@code ResultPartitionWriter} have emitted all the
      * user records.
      *
-     * @param shouldDrain tells if we should flush all records or not (it is false in case of
+     * @param mode tells if we should flush all records or not (it is NO_DRAIN in case of
      *     stop-with-savepoint (--no-drain))
      */
-    void notifyEndOfData(boolean shouldDrain) throws IOException;
+    void notifyEndOfData(StopMode mode) throws IOException;
 
     /**
      * Gets the future indicating whether all the records has been processed by the downstream
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
index 9b6a724..85aaaa5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.util.function.SupplierWithException;
@@ -66,9 +67,9 @@ public class BoundedBlockingResultPartition extends BufferWritingResultPartition
     }
 
     @Override
-    public void notifyEndOfData(boolean shouldDrain) throws IOException {
+    public void notifyEndOfData(StopMode mode) throws IOException {
         if (!hasNotifiedEndOfUserRecords) {
-            broadcastEvent(new EndOfData(shouldDrain), false);
+            broadcastEvent(new EndOfData(mode), false);
             hasNotifiedEndOfUserRecords = true;
         }
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
index dd00bb1..f25a54a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
 import org.apache.flink.util.function.SupplierWithException;
@@ -193,10 +194,10 @@ public class PipelinedResultPartition extends BufferWritingResultPartition
     }
 
     @Override
-    public void notifyEndOfData(boolean shouldDrain) throws IOException {
+    public void notifyEndOfData(StopMode mode) throws IOException {
         synchronized (lock) {
             if (!hasNotifiedEndOfUserRecords) {
-                broadcastEvent(new EndOfData(shouldDrain), false);
+                broadcastEvent(new EndOfData(mode), false);
                 hasNotifiedEndOfUserRecords = true;
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 87561ec..53aec93 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -23,6 +23,7 @@ import org.apache.flink.metrics.Counter;
 import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
@@ -200,7 +201,7 @@ public abstract class ResultPartition implements ResultPartitionWriter {
     // ------------------------------------------------------------------------
 
     @Override
-    public void notifyEndOfData(boolean shouldDrain) throws IOException {
+    public void notifyEndOfData(StopMode mode) throws IOException {
         throw new UnsupportedOperationException();
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index a8e8d31..d53c8ec 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.disk.BatchShuffleReadBufferPool;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
@@ -398,10 +399,10 @@ public class SortMergeResultPartition extends ResultPartition {
     }
 
     @Override
-    public void notifyEndOfData(boolean shouldDrain) throws IOException {
+    public void notifyEndOfData(StopMode mode) throws IOException {
         synchronized (lock) {
             if (!hasNotifiedEndOfUserRecords) {
-                broadcastEvent(new EndOfData(shouldDrain), false);
+                broadcastEvent(new EndOfData(mode), false);
                 hasNotifiedEndOfUserRecords = true;
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
index 012aec6..1934153 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
@@ -22,7 +22,6 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.PullingAsyncDataInput;
-import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.partition.ChannelStateHolder;
 
 import java.io.IOException;
@@ -99,17 +98,6 @@ public abstract class InputGate
 
     public abstract boolean isFinished();
 
-    public abstract boolean hasReceivedEndOfData();
-
-    /**
-     * Tells if we should drain all results in case we received {@link EndOfData} on all channels.
-     * If any of the upstream subtasks finished because of the stop-with-savepoint --no-drain, we
-     * should not drain the current task. See also {@code StopMode}.
-     *
-     * <p>We should check the {@link #hasReceivedEndOfData()} first.
-     */
-    public abstract boolean shouldDrainOnEndOfData();
-
     /**
      * Blocking call waiting for next {@link BufferOrEvent}.
      *
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index ee78ff6..540c036 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -29,6 +29,7 @@ import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
@@ -667,13 +668,14 @@ public class SingleInputGate extends IndexedInputGate {
     }
 
     @Override
-    public boolean hasReceivedEndOfData() {
-        return hasReceivedEndOfData;
-    }
-
-    @Override
-    public boolean shouldDrainOnEndOfData() {
-        return shouldDrainOnEndOfData;
+    public EndOfDataStatus hasReceivedEndOfData() {
+        if (!hasReceivedEndOfData) {
+            return EndOfDataStatus.NOT_END_OF_DATA;
+        } else if (shouldDrainOnEndOfData) {
+            return EndOfDataStatus.DRAINED;
+        } else {
+            return EndOfDataStatus.STOPPED;
+        }
     }
 
     @Override
@@ -855,7 +857,7 @@ public class SingleInputGate extends IndexedInputGate {
                 channelsWithEndOfUserRecords.set(currentChannel.getChannelIndex());
                 hasReceivedEndOfData =
                         channelsWithEndOfUserRecords.cardinality() == numberOfInputChannels;
-                shouldDrainOnEndOfData &= ((EndOfData) event).shouldDrain();
+                shouldDrainOnEndOfData &= ((EndOfData) event).getStopMode() == StopMode.DRAIN;
             }
         }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index 4154615..415e1bf 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -181,13 +181,14 @@ public class UnionInputGate extends InputGate {
     }
 
     @Override
-    public boolean hasReceivedEndOfData() {
-        return inputGatesWithRemainingUserData.isEmpty();
-    }
-
-    @Override
-    public boolean shouldDrainOnEndOfData() {
-        return shouldDrainOnEndOfData;
+    public EndOfDataStatus hasReceivedEndOfData() {
+        if (!inputGatesWithRemainingUserData.isEmpty()) {
+            return EndOfDataStatus.NOT_END_OF_DATA;
+        } else if (shouldDrainOnEndOfData) {
+            return EndOfDataStatus.DRAINED;
+        } else {
+            return EndOfDataStatus.STOPPED;
+        }
     }
 
     @Override
@@ -293,9 +294,9 @@ public class UnionInputGate extends InputGate {
     private void handleEndOfUserDataEvent(BufferOrEvent bufferOrEvent, InputGate inputGate) {
         if (bufferOrEvent.isEvent()
                 && bufferOrEvent.getEvent().getClass() == EndOfData.class
-                && inputGate.hasReceivedEndOfData()) {
+                && inputGate.hasReceivedEndOfData() != EndOfDataStatus.NOT_END_OF_DATA) {
 
-            shouldDrainOnEndOfData &= inputGate.shouldDrainOnEndOfData();
+            shouldDrainOnEndOfData &= inputGate.hasReceivedEndOfData() == EndOfDataStatus.DRAINED;
             if (!inputGatesWithRemainingUserData.remove(inputGate)) {
                 throw new IllegalStateException(
                         "Couldn't find input gate in set of remaining input gates.");
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
index a7d425c..25f585b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.taskmanager;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.BufferAvailabilityListener;
 import org.apache.flink.runtime.io.network.partition.CheckpointedResultPartition;
@@ -149,8 +150,8 @@ public class ConsumableNotifyingResultPartitionWriterDecorator {
         }
 
         @Override
-        public void notifyEndOfData(boolean shouldDrain) throws IOException {
-            partitionWriter.notifyEndOfData(shouldDrain);
+        public void notifyEndOfData(StopMode mode) throws IOException {
+            partitionWriter.notifyEndOfData(mode);
         }
 
         @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
index dae257d..e3f6cfd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
@@ -96,16 +96,11 @@ public class InputGateWithMetrics extends IndexedInputGate {
     }
 
     @Override
-    public boolean hasReceivedEndOfData() {
+    public EndOfDataStatus hasReceivedEndOfData() {
         return inputGate.hasReceivedEndOfData();
     }
 
     @Override
-    public boolean shouldDrainOnEndOfData() {
-        return inputGate.shouldDrainOnEndOfData();
-    }
-
-    @Override
     public void setup() throws IOException {
         inputGate.setup();
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
index 6a2ef20..67f2d9a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.EndOfSuperstepEvent;
 import org.apache.flink.runtime.io.network.api.EventAnnouncement;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.SubtaskConnectionDescriptor;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -49,8 +50,8 @@ public class EventSerializerTest {
     private final AbstractEvent[] events = {
         EndOfPartitionEvent.INSTANCE,
         EndOfSuperstepEvent.INSTANCE,
-        new EndOfData(true),
-        new EndOfData(false),
+        new EndOfData(StopMode.DRAIN),
+        new EndOfData(StopMode.NO_DRAIN),
         new CheckpointBarrier(
                 1678L,
                 4623784L,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
index bd2b667..a192090f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.io.network.api.writer;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.api.serialization.RecordDeserializer;
 import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
@@ -89,9 +90,9 @@ public class RecordOrEventCollectingResultPartitionWriter<T>
     }
 
     @Override
-    public void notifyEndOfData(boolean shouldDrain) throws IOException {
+    public void notifyEndOfData(StopMode mode) throws IOException {
         if (collectNetworkEvents) {
-            broadcastEvent(new EndOfData(shouldDrain), false);
+            broadcastEvent(new EndOfData(mode), false);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
index 9e1973f..67cf025 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.runtime.io.network.netty;
 
 import org.apache.flink.runtime.io.network.NetworkSequenceViewReader;
 import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.netty.NettyMessage.ErrorResponse;
 import org.apache.flink.runtime.io.network.netty.NettyMessage.PartitionRequest;
 import org.apache.flink.runtime.io.network.netty.NettyMessage.ResumeConsumption;
@@ -128,7 +129,7 @@ public class PartitionRequestServerHandlerTest extends TestLogger {
         partitionRequestQueue.notifyReaderCreated(viewReader);
 
         // Write the message to acknowledge all records are processed to server
-        resultPartition.notifyEndOfData(true);
+        resultPartition.notifyEndOfData(StopMode.DRAIN);
         CompletableFuture<Void> allRecordsProcessedFuture =
                 resultPartition.getAllDataProcessedFuture();
         assertFalse(allRecordsProcessedFuture.isDone());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
index 855ff82..4e519f8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.testutils.CheckedThread;
 import org.apache.flink.runtime.io.disk.FileChannelManager;
 import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -256,7 +257,7 @@ public class BoundedBlockingSubpartitionWriteReadTest {
 
     private void writeEndOfData(BoundedBlockingSubpartition subpartition) throws IOException {
         try (BufferConsumer eventBufferConsumer =
-                EventSerializer.toBufferConsumer(new EndOfData(true), false)) {
+                EventSerializer.toBufferConsumer(new EndOfData(StopMode.DRAIN), false)) {
             // Retain the buffer so that it can be recycled by each channel of targetPartition
             subpartition.add(eventBufferConsumer.copy(), 0);
         }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
index 6db08bd..4a6be76 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.runtime.event.AbstractEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 
@@ -61,7 +62,7 @@ public class MockResultPartitionWriter implements ResultPartitionWriter {
     public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {}
 
     @Override
-    public void notifyEndOfData(boolean shouldDrain) throws IOException {}
+    public void notifyEndOfData(StopMode mode) throws IOException {}
 
     @Override
     public CompletableFuture<Void> getAllDataProcessedFuture() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
index 06ac632..3bbd658 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
@@ -625,7 +626,7 @@ public class ResultPartitionTest {
         BufferWritingResultPartition bufferWritingResultPartition =
                 createResultPartition(ResultPartitionType.PIPELINED_BOUNDED);
 
-        bufferWritingResultPartition.notifyEndOfData(true);
+        bufferWritingResultPartition.notifyEndOfData(StopMode.DRAIN);
         CompletableFuture<Void> allRecordsProcessedFuture =
                 bufferWritingResultPartition.getAllDataProcessedFuture();
         assertFalse(allRecordsProcessedFuture.isDone());
@@ -634,7 +635,7 @@ public class ResultPartitionTest {
             Buffer nextBuffer = ((PipelinedSubpartition) resultSubpartition).pollBuffer().buffer();
             assertFalse(nextBuffer.isBuffer());
             assertEquals(
-                    new EndOfData(true),
+                    new EndOfData(StopMode.DRAIN),
                     EventSerializer.fromBuffer(nextBuffer, getClass().getClassLoader()));
         }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 1ea926ac..abf1d58 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.PullingAsyncDataInput;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.TaskEventDispatcher;
@@ -35,6 +36,7 @@ import org.apache.flink.runtime.io.network.TaskEventPublisher;
 import org.apache.flink.runtime.io.network.TestingConnectionManager;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
 import org.apache.flink.runtime.io.network.buffer.BufferDecompressor;
@@ -238,15 +240,19 @@ public class SingleInputGateTest extends InputGateTestBase {
         verifyBufferOrEvent(inputGate, true, 0, true);
         verifyBufferOrEvent(inputGate, false, 1, true);
         // we have received EndOfData on a single channel only
-        assertFalse(inputGate.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
+                inputGate.hasReceivedEndOfData());
         verifyBufferOrEvent(inputGate, false, 0, true);
         assertFalse(inputGate.isFinished());
-        assertTrue(inputGate.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.DRAINED, inputGate.hasReceivedEndOfData());
         verifyBufferOrEvent(inputGate, false, 1, true);
         verifyBufferOrEvent(inputGate, false, 0, false);
 
         // Return null when the input gate has received all end-of-partition events
-        assertTrue(inputGate.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.DRAINED, inputGate.hasReceivedEndOfData());
         assertTrue(inputGate.isFinished());
 
         for (TestInputChannel ic : inputChannels) {
@@ -272,11 +278,11 @@ public class SingleInputGateTest extends InputGateTestBase {
         inputGate2.setInputChannels(inputChannels2);
 
         // Test
-        inputChannels1[1].readEndOfData(true);
-        inputChannels1[0].readEndOfData(false);
+        inputChannels1[1].readEndOfData(StopMode.DRAIN);
+        inputChannels1[0].readEndOfData(StopMode.NO_DRAIN);
 
-        inputChannels2[1].readEndOfData(true);
-        inputChannels2[0].readEndOfData(true);
+        inputChannels2[1].readEndOfData(StopMode.DRAIN);
+        inputChannels2[0].readEndOfData(StopMode.DRAIN);
 
         inputGate1.notifyChannelNonEmpty(inputChannels1[0]);
         inputGate1.notifyChannelNonEmpty(inputChannels1[1]);
@@ -285,19 +291,23 @@ public class SingleInputGateTest extends InputGateTestBase {
 
         verifyBufferOrEvent(inputGate1, false, 0, true);
         // we have received EndOfData on a single channel only
-        assertFalse(inputGate1.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
+                inputGate1.hasReceivedEndOfData());
         verifyBufferOrEvent(inputGate1, false, 1, true);
-        assertTrue(inputGate1.hasReceivedEndOfData());
         // one of the channels said we should not drain
-        assertFalse(inputGate1.shouldDrainOnEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.STOPPED, inputGate1.hasReceivedEndOfData());
 
         verifyBufferOrEvent(inputGate2, false, 0, true);
         // we have received EndOfData on a single channel only
-        assertFalse(inputGate2.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
+                inputGate2.hasReceivedEndOfData());
         verifyBufferOrEvent(inputGate2, false, 1, true);
-        assertTrue(inputGate2.hasReceivedEndOfData());
         // both channels said we should drain
-        assertTrue(inputGate2.shouldDrainOnEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.DRAINED, inputGate2.hasReceivedEndOfData());
     }
 
     /**
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
index 3011d3c..c891a73 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
@@ -22,6 +22,7 @@ import org.apache.flink.metrics.SimpleCounter;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
@@ -107,13 +108,13 @@ public class TestInputChannel extends InputChannel {
     }
 
     TestInputChannel readEndOfData() throws IOException {
-        return readEndOfData(true);
+        return readEndOfData(StopMode.DRAIN);
     }
 
-    TestInputChannel readEndOfData(boolean shouldDrain) throws IOException {
+    TestInputChannel readEndOfData(StopMode mode) throws IOException {
         addBufferAndAvailability(
                 new BufferAndAvailability(
-                        EventSerializer.toBuffer(new EndOfData(shouldDrain), false),
+                        EventSerializer.toBuffer(new EndOfData(mode), false),
                         Buffer.DataType.EVENT_BUFFER,
                         0,
                         sequenceNumber++));
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index c9ca40e..de507e2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -19,6 +19,8 @@
 package org.apache.flink.runtime.io.network.partition.consumer;
 
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
+import org.apache.flink.runtime.io.PullingAsyncDataInput;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
 import org.apache.flink.runtime.io.network.partition.NoOpResultSubpartitionView;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
@@ -114,12 +116,14 @@ public class UnionInputGateTest extends InputGateTestBase {
         verifyBufferOrEvent(union, false, 0, true); // gate 1, channel 0
         verifyBufferOrEvent(union, false, 4, true); // gate 1, channel 1
         verifyBufferOrEvent(union, false, 1, true); // gate 1, channel 1
-        assertFalse(union.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
+                union.hasReceivedEndOfData());
         verifyBufferOrEvent(union, false, 5, true); // gate 2, channel 2
         verifyBufferOrEvent(union, false, 2, true); // gate 1, channel 2
         verifyBufferOrEvent(union, false, 6, true); // gate 2, channel 3
         verifyBufferOrEvent(union, false, 7, true); // gate 2, channel 4
-        assertTrue(union.hasReceivedEndOfData());
+        assertEquals(PullingAsyncDataInput.EndOfDataStatus.DRAINED, union.hasReceivedEndOfData());
         assertFalse(union.isFinished());
         verifyBufferOrEvent(union, false, 3, true); // gate 2, channel 0
         verifyBufferOrEvent(union, false, 4, true); // gate 2, channel 1
@@ -150,11 +154,11 @@ public class UnionInputGateTest extends InputGateTestBase {
         inputGate2.setInputChannels(inputChannels2);
 
         // Test
-        inputChannels1[1].readEndOfData(true);
-        inputChannels1[0].readEndOfData(false);
+        inputChannels1[1].readEndOfData(StopMode.DRAIN);
+        inputChannels1[0].readEndOfData(StopMode.NO_DRAIN);
 
-        inputChannels2[1].readEndOfData(true);
-        inputChannels2[0].readEndOfData(true);
+        inputChannels2[1].readEndOfData(StopMode.DRAIN);
+        inputChannels2[0].readEndOfData(StopMode.DRAIN);
 
         final UnionInputGate unionInputGate = new UnionInputGate(inputGate1, inputGate2);
 
@@ -166,13 +170,16 @@ public class UnionInputGateTest extends InputGateTestBase {
         verifyBufferOrEvent(unionInputGate, false, 0, true);
         verifyBufferOrEvent(unionInputGate, false, 2, true);
         // we have received EndOfData on a single input only
-        assertFalse(unionInputGate.hasReceivedEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.NOT_END_OF_DATA,
+                unionInputGate.hasReceivedEndOfData());
 
         verifyBufferOrEvent(unionInputGate, false, 1, true);
         verifyBufferOrEvent(unionInputGate, false, 3, true);
         // both channels received EndOfData, one channel said we should not drain
-        assertTrue(unionInputGate.hasReceivedEndOfData());
-        assertFalse(unionInputGate.shouldDrainOnEndOfData());
+        assertEquals(
+                PullingAsyncDataInput.EndOfDataStatus.STOPPED,
+                unionInputGate.hasReceivedEndOfData());
     }
 
     @Test
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/SourceOperator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/SourceOperator.java
index eb82768..9aed32d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/SourceOperator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/operators/SourceOperator.java
@@ -34,6 +34,7 @@ import org.apache.flink.core.io.InputStatus;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.metrics.groups.SourceReaderMetricGroup;
 import org.apache.flink.runtime.io.AvailabilityProvider;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.metrics.groups.InternalSourceReaderMetricGroup;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
@@ -146,6 +147,7 @@ public class SourceOperator<OUT, SplitT extends SourceSplit> extends AbstractStr
     private enum OperatingMode {
         READING,
         OUTPUT_NOT_INITIALIZED,
+        SOURCE_DRAINED,
         SOURCE_STOPPED,
         DATA_FINISHED
     }
@@ -305,23 +307,35 @@ public class SourceOperator<OUT, SplitT extends SourceSplit> extends AbstractStr
 
     @Override
     public void finish() throws Exception {
+        stopInternalServices();
+        super.finish();
+
+        finished.complete(null);
+    }
+
+    private void stopInternalServices() {
         if (eventTimeLogic != null) {
             eventTimeLogic.stopPeriodicWatermarkEmits();
         }
         if (latencyMarerEmitter != null) {
             latencyMarerEmitter.close();
         }
-        super.finish();
-
-        finished.complete(null);
     }
 
-    public CompletableFuture<Void> stop() {
+    public CompletableFuture<Void> stop(StopMode mode) {
         switch (operatingMode) {
             case OUTPUT_NOT_INITIALIZED:
             case READING:
-                this.operatingMode = OperatingMode.SOURCE_STOPPED;
+                this.operatingMode =
+                        mode == StopMode.DRAIN
+                                ? OperatingMode.SOURCE_DRAINED
+                                : OperatingMode.SOURCE_STOPPED;
                 availabilityHelper.forceStop();
+                if (this.operatingMode == OperatingMode.SOURCE_STOPPED) {
+                    stopInternalServices();
+                    finished.complete(null);
+                    return finished;
+                }
                 break;
         }
         return finished;
@@ -363,6 +377,10 @@ public class SourceOperator<OUT, SplitT extends SourceSplit> extends AbstractStr
             case SOURCE_STOPPED:
                 this.operatingMode = OperatingMode.DATA_FINISHED;
                 sourceMetricGroup.idlingStarted();
+                return DataInputStatus.STOPPED;
+            case SOURCE_DRAINED:
+                this.operatingMode = OperatingMode.DATA_FINISHED;
+                sourceMetricGroup.idlingStarted();
                 return DataInputStatus.END_OF_DATA;
             case DATA_FINISHED:
                 sourceMetricGroup.idlingStarted();
@@ -423,6 +441,7 @@ public class SourceOperator<OUT, SplitT extends SourceSplit> extends AbstractStr
             case READING:
                 return availabilityHelper.update(sourceReader.isAvailable());
             case SOURCE_STOPPED:
+            case SOURCE_DRAINED:
             case DATA_FINISHED:
                 return AvailabilityProvider.AVAILABLE;
             default:
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
index d60c824..6c59c99 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
@@ -151,10 +151,14 @@ public abstract class AbstractStreamTaskNetworkInput<
         // Event received
         final AbstractEvent event = bufferOrEvent.getEvent();
         if (event.getClass() == EndOfData.class) {
-            if (checkpointedInputGate.hasReceivedEndOfData()) {
-                return checkpointedInputGate.shouldDrainOnEndOfData()
-                        ? DataInputStatus.END_OF_DATA
-                        : DataInputStatus.STOPPED;
+            switch (checkpointedInputGate.hasReceivedEndOfData()) {
+                case NOT_END_OF_DATA:
+                    // skip
+                    break;
+                case DRAINED:
+                    return DataInputStatus.END_OF_DATA;
+                case STOPPED:
+                    return DataInputStatus.STOPPED;
             }
         } else if (event.getClass() == EndOfPartitionEvent.class) {
             // release the record deserializer immediately,
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
index b497e4d..a683c74 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
@@ -223,15 +223,10 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven
     }
 
     @Override
-    public boolean hasReceivedEndOfData() {
+    public EndOfDataStatus hasReceivedEndOfData() {
         return inputGate.hasReceivedEndOfData();
     }
 
-    @Override
-    public boolean shouldDrainOnEndOfData() {
-        return inputGate.shouldDrainOnEndOfData();
-    }
-
     /**
      * Cleans up all internally held resources.
      *
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/FinishedOperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/FinishedOperatorChain.java
index f22086f..1f9be0b 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/FinishedOperatorChain.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/FinishedOperatorChain.java
@@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
@@ -76,7 +77,8 @@ public class FinishedOperatorChain<OUT, OP extends StreamOperator<OUT>>
             StreamTaskStateInitializer streamTaskStateInitializer) {}
 
     @Override
-    public void finishOperators(StreamTaskActionExecutor actionExecutor) throws Exception {}
+    public void finishOperators(StreamTaskActionExecutor actionExecutor, StopMode stopMode)
+            throws Exception {}
 
     @Override
     public void notifyCheckpointComplete(long checkpointId) throws Exception {}
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
index 4f79660..fda39e2 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTask.java
@@ -25,6 +25,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
 import org.apache.flink.runtime.metrics.MetricNames;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -195,8 +196,8 @@ public class MultipleInputStreamTask<OUT>
         // EndOfPartitionEvent, we would not complement barriers for the
         // unfinished network inputs, and the checkpoint would be triggered
         // after received all the EndOfPartitionEvent.
-        if (options.getCheckpointType().shouldDrain()) {
-            return triggerStopWithSavepointWithDrainAsync(metadata, options);
+        if (options.getCheckpointType().isSynchronous()) {
+            return triggerStopWithSavepointAsync(metadata, options);
         } else {
             return triggerSourcesCheckpointAsync(metadata, options);
         }
@@ -235,17 +236,21 @@ public class MultipleInputStreamTask<OUT>
         return resultFuture;
     }
 
-    private CompletableFuture<Boolean> triggerStopWithSavepointWithDrainAsync(
+    private CompletableFuture<Boolean> triggerStopWithSavepointAsync(
             CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {
 
         CompletableFuture<Void> sourcesStopped = new CompletableFuture<>();
+        final StopMode stopMode =
+                checkpointOptions.getCheckpointType().shouldDrain()
+                        ? StopMode.DRAIN
+                        : StopMode.NO_DRAIN;
         mainMailboxExecutor.execute(
                 () -> {
-                    setSynchronousSavepoint(checkpointMetaData.getCheckpointId(), true);
+                    setSynchronousSavepoint(checkpointMetaData.getCheckpointId());
                     FutureUtils.forward(
                             FutureUtils.waitForAll(
                                     operatorChain.getSourceTaskInputs().stream()
-                                            .map(s -> s.getOperator().stop())
+                                            .map(s -> s.getOperator().stop(stopMode))
                                             .collect(Collectors.toList())),
                             sourcesStopped);
                 },
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
index a009452..f9f2e1c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/OperatorChain.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
@@ -290,7 +291,7 @@ public abstract class OperatorChain<OUT, OP extends StreamOperator<OUT>>
     /**
      * Initialize state and open all operators in the chain from <b>tail to heads</b>, contrary to
      * {@link StreamOperator#close()} which happens <b>heads to tail</b> (see {@link
-     * #finishOperators(StreamTaskActionExecutor)}).
+     * #finishOperators(StreamTaskActionExecutor, StopMode)}).
      */
     public abstract void initializeStateAndOpenOperators(
             StreamTaskStateInitializer streamTaskStateInitializer) throws Exception;
@@ -300,7 +301,8 @@ public abstract class OperatorChain<OUT, OP extends StreamOperator<OUT>>
      * operator in the chain, contrary to {@link StreamOperator#open()} which happens <b>tail to
      * heads</b> (see {@link #initializeStateAndOpenOperators(StreamTaskStateInitializer)}).
      */
-    public abstract void finishOperators(StreamTaskActionExecutor actionExecutor) throws Exception;
+    public abstract void finishOperators(StreamTaskActionExecutor actionExecutor, StopMode stopMode)
+            throws Exception;
 
     public abstract void notifyCheckpointComplete(long checkpointId) throws Exception;
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/RegularOperatorChain.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/RegularOperatorChain.java
index 7a267a7..078588e 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/RegularOperatorChain.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/RegularOperatorChain.java
@@ -23,6 +23,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriterDelegate;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
@@ -112,9 +113,10 @@ public class RegularOperatorChain<OUT, OP extends StreamOperator<OUT>>
     }
 
     @Override
-    public void finishOperators(StreamTaskActionExecutor actionExecutor) throws Exception {
+    public void finishOperators(StreamTaskActionExecutor actionExecutor, StopMode stopMode)
+            throws Exception {
         if (firstOperatorWrapper != null) {
-            firstOperatorWrapper.finish(actionExecutor);
+            firstOperatorWrapper.finish(actionExecutor, stopMode);
         }
     }
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
index 6ce7b26..b6b8c6c 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTask.java
@@ -26,6 +26,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.metrics.MetricNames;
 import org.apache.flink.runtime.metrics.groups.InternalSourceReaderMetricGroup;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
@@ -115,9 +116,8 @@ public class SourceOperatorStreamTask<T> extends StreamTask<T, SourceOperator<T,
     public CompletableFuture<Boolean> triggerCheckpointAsync(
             CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {
         if (!isExternallyInducedSource) {
-            if (checkpointOptions.getCheckpointType().shouldDrain()) {
-                return triggerStopWithSavepointWithDrainAsync(
-                        checkpointMetaData, checkpointOptions);
+            if (checkpointOptions.getCheckpointType().isSynchronous()) {
+                return triggerStopWithSavepointAsync(checkpointMetaData, checkpointOptions);
             } else {
                 return super.triggerCheckpointAsync(checkpointMetaData, checkpointOptions);
             }
@@ -126,16 +126,21 @@ public class SourceOperatorStreamTask<T> extends StreamTask<T, SourceOperator<T,
         }
     }
 
-    private CompletableFuture<Boolean> triggerStopWithSavepointWithDrainAsync(
+    private CompletableFuture<Boolean> triggerStopWithSavepointAsync(
             CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {
 
         CompletableFuture<Void> operatorFinished = new CompletableFuture<>();
         mainMailboxExecutor.execute(
                 () -> {
-                    setSynchronousSavepoint(checkpointMetaData.getCheckpointId(), true);
-                    FutureUtils.forward(mainOperator.stop(), operatorFinished);
+                    setSynchronousSavepoint(checkpointMetaData.getCheckpointId());
+                    FutureUtils.forward(
+                            mainOperator.stop(
+                                    checkpointOptions.getCheckpointType().shouldDrain()
+                                            ? StopMode.DRAIN
+                                            : StopMode.NO_DRAIN),
+                            operatorFinished);
                 },
-                "stop Flip-27 source for stop-with-savepoint --drain");
+                "stop Flip-27 source for stop-with-savepoint");
 
         return operatorFinished.thenCompose(
                 (ignore) -> super.triggerCheckpointAsync(checkpointMetaData, checkpointOptions));
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
index 0206ef8..1a94c19 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTask.java
@@ -24,6 +24,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.metrics.MetricNames;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.streaming.api.checkpoint.ExternallyInducedSource;
@@ -69,18 +70,18 @@ public class SourceStreamTask<
     private final AtomicBoolean stopped = new AtomicBoolean(false);
 
     private enum FinishingReason {
-        END_OF_DATA(true),
-        STOP_WITH_SAVEPOINT_DRAIN(true),
-        STOP_WITH_SAVEPOINT_NO_DRAIN(false);
+        END_OF_DATA(StopMode.DRAIN),
+        STOP_WITH_SAVEPOINT_DRAIN(StopMode.DRAIN),
+        STOP_WITH_SAVEPOINT_NO_DRAIN(StopMode.NO_DRAIN);
 
-        private final boolean shouldCallFinish;
+        private final StopMode stopMode;
 
-        FinishingReason(boolean shouldCallFinish) {
-            this.shouldCallFinish = shouldCallFinish;
+        FinishingReason(StopMode stopMode) {
+            this.stopMode = stopMode;
         }
 
-        boolean shouldCallFinish() {
-            return this.shouldCallFinish;
+        StopMode toStopMode() {
+            return this.stopMode;
         }
     }
 
@@ -265,9 +266,8 @@ public class SourceStreamTask<
     public CompletableFuture<Boolean> triggerCheckpointAsync(
             CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {
         if (!externallyInducedCheckpoints) {
-            if (checkpointOptions.getCheckpointType().shouldDrain()) {
-                return triggerStopWithSavepointWithDrainAsync(
-                        checkpointMetaData, checkpointOptions);
+            if (checkpointOptions.getCheckpointType().isSynchronous()) {
+                return triggerStopWithSavepointAsync(checkpointMetaData, checkpointOptions);
             } else {
                 return super.triggerCheckpointAsync(checkpointMetaData, checkpointOptions);
             }
@@ -279,12 +279,13 @@ public class SourceStreamTask<
         }
     }
 
-    private CompletableFuture<Boolean> triggerStopWithSavepointWithDrainAsync(
+    private CompletableFuture<Boolean> triggerStopWithSavepointAsync(
             CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {
         mainMailboxExecutor.execute(
                 () ->
-                        stopOperatorForStopWithSavepointWithDrain(
-                                checkpointMetaData.getCheckpointId()),
+                        stopOperatorForStopWithSavepoint(
+                                checkpointMetaData.getCheckpointId(),
+                                checkpointOptions.getCheckpointType().shouldDrain()),
-                "stop legacy source for stop-with-savepoint --drain");
+                "stop legacy source for stop-with-savepoint");
         return sourceThread
                 .getCompletionFuture()
@@ -294,9 +295,12 @@ public class SourceStreamTask<
                                         checkpointMetaData, checkpointOptions));
     }
 
-    private void stopOperatorForStopWithSavepointWithDrain(long checkpointId) {
-        setSynchronousSavepoint(checkpointId, true);
-        finishingReason = FinishingReason.STOP_WITH_SAVEPOINT_DRAIN;
+    private void stopOperatorForStopWithSavepoint(long checkpointId, boolean drain) {
+        setSynchronousSavepoint(checkpointId);
+        finishingReason =
+                drain
+                        ? FinishingReason.STOP_WITH_SAVEPOINT_DRAIN
+                        : FinishingReason.STOP_WITH_SAVEPOINT_NO_DRAIN;
         if (mainOperator != null) {
             mainOperator.stop();
         }
@@ -335,9 +339,6 @@ public class SourceStreamTask<
                         && ExceptionUtils.findThrowable(t, InterruptedException.class)
                                 .isPresent()) {
                     completionFuture.completeExceptionally(new CancelTaskException(t));
-                } else if (finishingReason == FinishingReason.STOP_WITH_SAVEPOINT_NO_DRAIN) {
-                    // swallow all exceptions if the source was stopped without drain
-                    completionFuture.complete(null);
                 } else {
                     completionFuture.completeExceptionally(t);
                 }
@@ -345,15 +346,17 @@ public class SourceStreamTask<
         }
 
         private void completeProcessing() throws InterruptedException, ExecutionException {
-            if (finishingReason.shouldCallFinish() && !isCanceled() && !isFailing()) {
+            if (!isCanceled() && !isFailing()) {
                 mainMailboxExecutor
                         .submit(
                                 () -> {
                                     // theoretically the StreamSource can implement BoundedOneInput,
-                                    // so we
-                                    // need to call it here
-                                    operatorChain.endInput(1);
-                                    endData();
+                                    // so we need to call it here
+                                    final StopMode stopMode = finishingReason.toStopMode();
+                                    if (stopMode == StopMode.DRAIN) {
+                                        operatorChain.endInput(1);
+                                    }
+                                    endData(stopMode);
                                 },
                                 "SourceStreamTask finished processing data.")
                         .get();
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapper.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapper.java
index e8a24ae..14c296e 100755
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapper.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapper.java
@@ -19,6 +19,7 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.operators.MailboxExecutor;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.streaming.api.operators.BoundedMultiInput;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.StreamOperator;
@@ -39,8 +40,8 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * This class handles the finish, endInput and other related logic of a {@link StreamOperator}. It
  * also automatically propagates the finish operation to the next wrapper that the {@link #next}
  * points to, so we can use {@link #next} to link all operator wrappers in the operator chain and
- * finish all operators only by calling the {@link #finish(StreamTaskActionExecutor)} method of the
- * header operator wrapper.
+ * finish all operators only by calling the {@link #finish(StreamTaskActionExecutor, StopMode)}
+ * method of the header operator wrapper.
  */
 @Internal
 public class StreamOperatorWrapper<OUT, OP extends StreamOperator<OUT>> {
@@ -120,18 +121,19 @@ public class StreamOperatorWrapper<OUT, OP extends StreamOperator<OUT>> {
      * MailboxExecutor#yield()} to take the mails of closing operator and running timers and run
      * them.
      */
-    public void finish(StreamTaskActionExecutor actionExecutor) throws Exception {
-        if (!isHead) {
+    public void finish(StreamTaskActionExecutor actionExecutor, StopMode stopMode)
+            throws Exception {
+        if (!isHead && stopMode == StopMode.DRAIN) {
             // NOTE: This only do for the case where the operator is one-input operator. At present,
             // any non-head operator on the operator chain is one-input operator.
             actionExecutor.runThrowing(() -> endOperatorInput(1));
         }
 
-        quiesceTimeServiceAndFinishOperator(actionExecutor);
+        quiesceTimeServiceAndFinishOperator(actionExecutor, stopMode);
 
         // propagate the close operation to the next wrapper
         if (next != null) {
-            next.finish(actionExecutor);
+            next.finish(actionExecutor, stopMode);
         }
     }
 
@@ -141,7 +143,8 @@ public class StreamOperatorWrapper<OUT, OP extends StreamOperator<OUT>> {
         wrapped.close();
     }
 
-    private void quiesceTimeServiceAndFinishOperator(StreamTaskActionExecutor actionExecutor)
+    private void quiesceTimeServiceAndFinishOperator(
+            StreamTaskActionExecutor actionExecutor, StopMode stopMode)
             throws InterruptedException, ExecutionException {
 
         // step 1. to ensure that there is no longer output triggered by the timers before invoking
@@ -154,7 +157,8 @@ public class StreamOperatorWrapper<OUT, OP extends StreamOperator<OUT>> {
         // in the mailbox are completed before exiting the following mailbox processing loop
         CompletableFuture<Void> finishedFuture =
                 quiesceProcessingTimeService()
-                        .thenCompose(unused -> deferFinishOperatorToMailbox(actionExecutor))
+                        .thenCompose(
+                                unused -> deferFinishOperatorToMailbox(actionExecutor, stopMode))
                         .thenCompose(unused -> sendFinishedMail());
 
         // run the mailbox processing loop until all operations are finished
@@ -176,7 +180,11 @@ public class StreamOperatorWrapper<OUT, OP extends StreamOperator<OUT>> {
     }
 
     private CompletableFuture<Void> deferFinishOperatorToMailbox(
-            StreamTaskActionExecutor actionExecutor) {
+            StreamTaskActionExecutor actionExecutor, StopMode stopMode) {
+        if (stopMode == StopMode.NO_DRAIN) {
+            return CompletableFuture.completedFuture(null);
+        }
+
         final CompletableFuture<Void> finishOperatorFuture = new CompletableFuture<>();
 
         mailboxExecutor.execute(
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 42a1c2b..558a60f 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.AvailabilityProvider;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.MultipleRecordWriters;
 import org.apache.flink.runtime.io.network.api.writer.NonRecordWriter;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
@@ -277,8 +278,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
     // ========================================================
     //  Final  checkpoint / savepoint
     // ========================================================
-    private Long syncSavepointWithoutDrain = null;
-    private Long syncSavepointWithDrain = null;
+    private Long syncSavepoint = null;
     private Long finalCheckpointMinId = null;
     private final CompletableFuture<Void> finalCheckpointCompleted = new CompletableFuture<>();
 
@@ -526,9 +526,10 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
             case END_OF_RECOVERY:
                 throw new IllegalStateException("We should not receive this event here.");
             case STOPPED:
-                throw new UnsupportedOperationException("Not supported yet");
+                endData(StopMode.NO_DRAIN);
+                return;
             case END_OF_DATA:
-                endData();
+                endData(StopMode.DRAIN);
                 return;
             case END_OF_INPUT:
                 // Suspend the mailbox processor, it would be resumed in afterInvoke and finished
@@ -560,71 +561,40 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                         new ResumeWrapper(controller.suspendDefaultAction(timer), timer)));
     }
 
-    protected void endData() throws Exception {
+    protected void endData(StopMode mode) throws Exception {
 
-        if (syncSavepointWithoutDrain != null && areCheckpointsWithFinishedTasksEnabled()) {
-            throw new FlinkRuntimeException(
-                    "We run out of data to process while waiting for a synchronous savepoint"
-                            + " to be finished. This can lead to a deadlock waiting for a final"
-                            + " checkpoint after a synchronous savepoint, which will never be"
-                            + " triggered.");
+        if (mode == StopMode.DRAIN) {
+            advanceToEndOfEventTime();
         }
-
-        advanceToEndOfEventTime();
         // finish all operators in a chain effect way
-        operatorChain.finishOperators(actionExecutor);
+        operatorChain.finishOperators(actionExecutor, mode);
         this.finishedOperators = true;
 
         for (ResultPartitionWriter partitionWriter : getEnvironment().getAllWriters()) {
-            partitionWriter.notifyEndOfData(true);
+            partitionWriter.notifyEndOfData(mode);
         }
 
         this.endOfDataReceived = true;
     }
 
-    protected void setSynchronousSavepoint(long checkpointId, boolean isDrain) {
+    protected void setSynchronousSavepoint(long checkpointId) {
         checkState(
-                syncSavepointWithoutDrain == null
-                        && (syncSavepointWithDrain == null
-                                || (isDrain && syncSavepointWithDrain == checkpointId)),
+                syncSavepoint == null || syncSavepoint == checkpointId,
                 "at most one stop-with-savepoint checkpoint at a time is allowed");
-        if (isDrain) {
-            if (syncSavepointWithDrain == null) {
-                syncSavepointWithDrain = checkpointId;
-            }
-        } else {
-            syncSavepointWithoutDrain = checkpointId;
-        }
+        syncSavepoint = checkpointId;
     }
 
     @VisibleForTesting
     OptionalLong getSynchronousSavepointId() {
-        if (syncSavepointWithoutDrain != null) {
-            return OptionalLong.of(syncSavepointWithoutDrain);
-        } else if (syncSavepointWithDrain != null) {
-            return OptionalLong.of(syncSavepointWithDrain);
+        if (syncSavepoint != null) {
+            return OptionalLong.of(syncSavepoint);
         } else {
             return OptionalLong.empty();
         }
     }
 
-    private boolean isCurrentSavepointWithDrain(long checkpointId) {
-        return syncSavepointWithDrain != null && syncSavepointWithDrain == checkpointId;
-    }
-
-    private boolean isCurrentSavepointWithoutDrain(long checkpointId) {
-        return syncSavepointWithoutDrain != null && syncSavepointWithoutDrain == checkpointId;
-    }
-
-    private void runSynchronousSavepointMailboxLoop() throws Exception {
-        assert syncSavepointWithoutDrain != null;
-
-        MailboxExecutor mailboxExecutor =
-                mailboxProcessor.getMailboxExecutor(TaskMailbox.MAX_PRIORITY);
-
-        while (!canceled && syncSavepointWithoutDrain != null) {
-            mailboxExecutor.yield();
-        }
+    private boolean isCurrentSyncSavepoint(long checkpointId) {
+        return syncSavepoint != null && syncSavepoint == checkpointId;
     }
 
     /**
@@ -856,7 +826,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
             terminationConditions.add(finalCheckpointCompleted);
         }
 
-        if (syncSavepointWithDrain != null) {
+        if (syncSavepoint != null) {
             terminationConditions.add(finalCheckpointCompleted);
         }
 
@@ -1228,11 +1198,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 
         FlinkSecurityManager.monitorUserSystemExitForCurrentThread();
         try {
-            if (performCheckpoint(checkpointMetaData, checkpointOptions, checkpointMetrics)) {
-                if (isCurrentSavepointWithoutDrain(checkpointMetaData.getCheckpointId())) {
-                    runSynchronousSavepointMailboxLoop();
-                }
-            }
+            performCheckpoint(checkpointMetaData, checkpointOptions, checkpointMetrics);
         } catch (CancelTaskException e) {
             LOG.info(
                     "Operator {} was cancelled while performing checkpoint {}.",
@@ -1255,10 +1221,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
     @Override
     public void abortCheckpointOnBarrier(long checkpointId, CheckpointException cause)
             throws IOException {
-        if (isCurrentSavepointWithoutDrain(checkpointId)) {
-            syncSavepointWithoutDrain = null;
-        } else if (isCurrentSavepointWithDrain(checkpointId)) {
-            throw new FlinkRuntimeException("Stop-with-savepoint --drain failed.");
+        if (isCurrentSyncSavepoint(checkpointId)) {
+            throw new FlinkRuntimeException("Stop-with-savepoint failed.");
         }
         subtaskCheckpointCoordinator.abortCheckpointOnBarrier(checkpointId, cause, operatorChain);
     }
@@ -1276,21 +1240,11 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                 checkpointType,
                 getName());
 
-        if (checkpointType.isSynchronous()
-                && !checkpointType.shouldDrain()
-                && endOfDataReceived
-                && areCheckpointsWithFinishedTasksEnabled()) {
-            LOG.debug("Can not trigger a stop-with-savepoint w/o drain if a task is finishing.");
-            return false;
-        }
-
         if (isRunning) {
             actionExecutor.runThrowing(
                     () -> {
                         if (checkpointType.isSynchronous()) {
-                            setSynchronousSavepoint(
-                                    checkpointMetaData.getCheckpointId(),
-                                    checkpointType.shouldDrain());
+                            setSynchronousSavepoint(checkpointMetaData.getCheckpointId());
                         }
 
                         if (areCheckpointsWithFinishedTasksEnabled()
@@ -1357,10 +1311,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                         notifyCheckpointComplete(latestCompletedCheckpointId);
                     }
 
-                    if (isCurrentSavepointWithoutDrain(checkpointId)) {
-                        syncSavepointWithoutDrain = null;
-                    } else if (isCurrentSavepointWithDrain(checkpointId)) {
-                        throw new FlinkRuntimeException("Stop-with-savepoint --drain failed.");
+                    if (isCurrentSyncSavepoint(checkpointId)) {
+                        throw new FlinkRuntimeException("Stop-with-savepoint failed.");
                     }
                     subtaskCheckpointCoordinator.notifyCheckpointAborted(
                             checkpointId, operatorChain, this::isRunning);
@@ -1399,13 +1351,9 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
         subtaskCheckpointCoordinator.notifyCheckpointComplete(
                 checkpointId, operatorChain, this::isRunning);
         if (isRunning) {
-            if (isCurrentSavepointWithoutDrain(checkpointId)) {
-                finishTask();
-                // Reset to "notify" the internal synchronous savepoint mailbox loop.
-                syncSavepointWithoutDrain = null;
-            } else if (isCurrentSavepointWithDrain(checkpointId)) {
+            if (isCurrentSyncSavepoint(checkpointId)) {
                 finalCheckpointCompleted.complete(null);
-            } else if (syncSavepointWithDrain == null
+            } else if (syncSavepoint == null
                     && finalCheckpointMinId != null
                     && checkpointId >= finalCheckpointMinId) {
                 finalCheckpointCompleted.complete(null);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
index 32b6e96..cf7418d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.memory.DataOutputSerializer;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
@@ -124,7 +125,8 @@ public class StreamTestSingleInputGate<T> {
                         } else if (input != null && input.isDataEnd()) {
                             return Optional.of(
                                     new BufferAndAvailability(
-                                            EventSerializer.toBuffer(new EndOfData(true), false),
+                                            EventSerializer.toBuffer(
+                                                    new EndOfData(StopMode.DRAIN), false),
                                             nextType,
                                             0,
                                             0));
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/SourceOperatorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/SourceOperatorTest.java
index aa0fb7e..f318ccc 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/SourceOperatorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/SourceOperatorTest.java
@@ -29,6 +29,7 @@ import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.core.io.SimpleVersionedSerialization;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.io.AvailabilityProvider;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.operators.coordination.MockOperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
@@ -148,7 +149,7 @@ public class SourceOperatorTest {
         assertEquals(DataInputStatus.NOTHING_AVAILABLE, operator.emitNext(dataOutput));
         assertFalse(operator.isAvailable());
 
-        CompletableFuture<Void> sourceStopped = operator.stop();
+        CompletableFuture<Void> sourceStopped = operator.stop(StopMode.DRAIN);
         assertTrue(operator.isAvailable());
         assertFalse(sourceStopped.isDone());
         assertEquals(DataInputStatus.END_OF_DATA, operator.emitNext(dataOutput));
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
index 8bbe8fa..ae4721c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
@@ -96,13 +96,8 @@ public class MockIndexedInputGate extends IndexedInputGate {
     }
 
     @Override
-    public boolean hasReceivedEndOfData() {
-        return false;
-    }
-
-    @Override
-    public boolean shouldDrainOnEndOfData() {
-        return false;
+    public EndOfDataStatus hasReceivedEndOfData() {
+        return EndOfDataStatus.NOT_END_OF_DATA;
     }
 
     @Override
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
index 993361e..1c095d6 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
@@ -114,12 +114,7 @@ public class MockInputGate extends IndexedInputGate {
     }
 
     @Override
-    public boolean hasReceivedEndOfData() {
-        throw new UnsupportedOperationException("Not implemented yet");
-    }
-
-    @Override
-    public boolean shouldDrainOnEndOfData() {
+    public EndOfDataStatus hasReceivedEndOfData() {
         throw new UnsupportedOperationException("Not implemented yet");
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
index 903ec71..42086fa 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
@@ -168,13 +168,8 @@ public class AlignedCheckpointsMassiveRandomTest {
         }
 
         @Override
-        public boolean hasReceivedEndOfData() {
-            return false;
-        }
-
-        @Override
-        public boolean shouldDrainOnEndOfData() {
-            return false;
+        public EndOfDataStatus hasReceivedEndOfData() {
+            return EndOfDataStatus.NOT_END_OF_DATA;
         }
 
         @Override
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
index d433b69..521dbd0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
@@ -32,6 +32,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.RecordOrEventCollectingResultPartitionWriter;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.PartitionTestUtils;
@@ -281,10 +282,10 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
             CheckpointBarrier barrier = createStopWithSavepointDrainBarrier();
 
             testHarness.processElement(new StreamRecord<>("44", TimestampAssigner.NO_TIMESTAMP), 0);
-            testHarness.processEvent(new EndOfData(true), 0);
+            testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0);
             testHarness.processEvent(barrier, 0);
             testHarness.processElement(new StreamRecord<>(47d, TimestampAssigner.NO_TIMESTAMP), 1);
-            testHarness.processEvent(new EndOfData(true), 1);
+            testHarness.processEvent(new EndOfData(StopMode.DRAIN), 1);
             testHarness.processEvent(barrier, 1);
 
             addSourceRecords(testHarness, 1, Boundedness.CONTINUOUS_UNBOUNDED, 1, 2);
@@ -311,7 +312,7 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                     containsInAnyOrder(expectedOutput.toArray()));
             assertThat(
                     actualOutput.subList(actualOutput.size() - 3, actualOutput.size()),
-                    contains(new StreamRecord<>("FINISH"), new EndOfData(true), barrier));
+                    contains(new StreamRecord<>("FINISH"), new EndOfData(StopMode.DRAIN), barrier));
         }
     }
 
@@ -435,8 +436,8 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                 testHarness.processAll();
 
                 // The checkpoint 2 would be aligned after received all the EndOfPartitionEvent.
-                testHarness.processEvent(new EndOfData(true), 0, 0);
-                testHarness.processEvent(new EndOfData(true), 1, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 1, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
                 testHarness.getTaskStateManager().getWaitForReportLatch().await();
@@ -493,9 +494,8 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData(boolean shouldDrain)
-                                            throws IOException {
-                                        broadcastEvent(new EndOfData(shouldDrain), false);
+                                    public void notifyEndOfData(StopMode mode) throws IOException {
+                                        broadcastEvent(new EndOfData(mode), false);
                                     }
                                 })
                         .addSourceInput(
@@ -523,7 +523,7 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
             testHarness.processElement(Watermark.MAX_WATERMARK);
             assertThat(output, is(empty()));
             testHarness.waitForTaskCompletion();
-            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(StopMode.DRAIN)));
 
             for (StreamOperatorWrapper<?, ?> wrapper :
                     testHarness.getStreamTask().operatorChain.getAllOperators()) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
index c31dd3d..682e873 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
@@ -52,6 +52,7 @@ import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.PartitionTestUtils;
 import org.apache.flink.runtime.io.network.partition.ResultPartition;
@@ -989,15 +990,15 @@ public class MultipleInputStreamTaskTest {
                 assertEquals(2, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(new EndOfData(true), 0, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4, checkpointOptions);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
                 assertEquals(4, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after all the inputs have received EndOfPartition.
-                testHarness.processEvent(new EndOfData(true), 1, 0);
-                testHarness.processEvent(new EndOfData(true), 2, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 1, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 2, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 2, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 6, checkpointOptions);
@@ -1059,7 +1060,7 @@ public class MultipleInputStreamTaskTest {
             testHarness.waitForTaskCompletion();
             assertThat(
                     testHarness.getOutput(),
-                    contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
+                    contains(Watermark.MAX_WATERMARK, new EndOfData(StopMode.DRAIN)));
         }
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
index d4fa9be..5c73c43 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
@@ -42,6 +42,7 @@ import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.RecordOrEventCollectingResultPartitionWriter;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.source.event.AddSplitEvent;
@@ -129,7 +130,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             Queue<Object> expectedOutput = new LinkedList<>();
             expectedOutput.add(Watermark.MAX_WATERMARK);
-            expectedOutput.add(new EndOfData(true));
+            expectedOutput.add(new EndOfData(StopMode.DRAIN));
             expectedOutput.add(
                     new CheckpointBarrier(checkpointId, checkpointId, checkpointOptions));
 
@@ -145,7 +146,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             Queue<Object> expectedOutput = new LinkedList<>();
             expectedOutput.add(Watermark.MAX_WATERMARK);
-            expectedOutput.add(new EndOfData(true));
+            expectedOutput.add(new EndOfData(StopMode.DRAIN));
             assertThat(testHarness.getOutput().toArray(), equalTo(expectedOutput.toArray()));
         }
     }
@@ -203,9 +204,8 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData(boolean shouldDrain)
-                                            throws IOException {
-                                        broadcastEvent(new EndOfData(shouldDrain), false);
+                                    public void notifyEndOfData(StopMode mode) throws IOException {
+                                        broadcastEvent(new EndOfData(mode), false);
                                     }
                                 })
                         .setupOperatorChain(sourceOperatorFactory)
@@ -215,7 +215,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             testHarness.getStreamTask().invoke();
             testHarness.processAll();
-            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(StopMode.DRAIN)));
 
             LifeCycleMonitorSourceReader sourceReader =
                     (LifeCycleMonitorSourceReader)
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
index 6553279..0a79089 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
@@ -28,7 +28,6 @@ import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.testutils.MultiShotLatch;
-import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
@@ -38,6 +37,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.RecordOrEventCollectingResultPartitionWriter;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.PartitionTestUtils;
@@ -72,8 +72,6 @@ import org.apache.flink.util.function.CheckedSupplier;
 import org.junit.Assert;
 import org.junit.Test;
 
-import javax.annotation.Nonnull;
-
 import java.io.IOException;
 import java.io.Serializable;
 import java.util.ArrayList;
@@ -89,14 +87,11 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.Future;
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.Semaphore;
-import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 
 import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.INT_TYPE_INFO;
 import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.STRING_TYPE_INFO;
-import static org.apache.flink.runtime.checkpoint.CheckpointType.SAVEPOINT_SUSPEND;
 import static org.apache.flink.runtime.checkpoint.CheckpointType.SAVEPOINT_TERMINATE;
-import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
 import static org.apache.flink.streaming.runtime.tasks.StreamTaskFinalCheckpointsTest.triggerCheckpoint;
 import static org.apache.flink.util.Preconditions.checkState;
 import static org.hamcrest.CoreMatchers.is;
@@ -113,39 +108,6 @@ import static org.junit.Assert.assertTrue;
  */
 public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
 
-    @Test
-    public void testInputEndedBeforeStopWithSavepointConfirmed() throws Exception {
-        CancelTestSource source =
-                new CancelTestSource(
-                        STRING_TYPE_INFO.createSerializer(new ExecutionConfig()), "src");
-        TestBoundedOneInputStreamOperator chainTail = new TestBoundedOneInputStreamOperator("t");
-        StreamTaskMailboxTestHarness<String> harness =
-                new StreamTaskMailboxTestHarnessBuilder<>(SourceStreamTask::new, STRING_TYPE_INFO)
-                        .setupOperatorChain(
-                                new OperatorID(),
-                                new StreamSource<String, CancelTestSource<String>>(source))
-                        .chain(
-                                new OperatorID(),
-                                chainTail,
-                                STRING_TYPE_INFO.createSerializer(new ExecutionConfig()))
-                        .finish()
-                        .build();
-        Future<Boolean> triggerFuture =
-                harness.streamTask.triggerCheckpointAsync(
-                        new CheckpointMetaData(1, 1),
-                        new CheckpointOptions(SAVEPOINT_SUSPEND, getDefault()));
-        while (!triggerFuture.isDone()) {
-            harness.streamTask.runMailboxStep();
-        }
-        // instead of completing stop with savepoint via `notifyCheckpointCompleted`
-        // we simulate that source has finished first. As a result, we expect that the endInput
-        // should have been issued
-        source.cancel();
-        harness.streamTask.invoke();
-        harness.waitForTaskCompletion();
-        assertTrue(TestBoundedOneInputStreamOperator.isInputEnded());
-    }
-
     /** This test verifies that open() and close() are correctly called by the StreamTask. */
     @Test
     public void testOpenClose() throws Exception {
@@ -528,27 +490,6 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
         public void cancel() {}
     }
 
-    /** If finishing a task doesn't swallow exceptions this test would fail with an exception. */
-    @Test
-    public void finishingIgnoresExceptions() throws Exception {
-        final StreamTaskTestHarness<String> testHarness =
-                new StreamTaskTestHarness<>(SourceStreamTask::new, STRING_TYPE_INFO);
-
-        final CompletableFuture<Void> operatorRunningWaitingFuture = new CompletableFuture<>();
-        ExceptionThrowingSource.setIsInRunLoopFuture(operatorRunningWaitingFuture);
-
-        testHarness.setupOutputForSingletonOperatorChain();
-        StreamConfig streamConfig = testHarness.getStreamConfig();
-        streamConfig.setStreamOperator(new StreamSource<>(new ExceptionThrowingSource()));
-        streamConfig.setOperatorID(new OperatorID());
-
-        testHarness.invoke();
-        operatorRunningWaitingFuture.get();
-        testHarness.getTask().finishTask();
-
-        testHarness.waitForTaskCompletion();
-    }
-
     @Test
     public void testWaitsForSourceThreadOnCancel() throws Exception {
         StreamTaskTestHarness<String> harness =
@@ -578,45 +519,6 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
     }
 
     @Test
-    public void testStopWithSavepointShouldNotInterruptTheSource() throws Exception {
-        long checkpointId = 1;
-        WasInterruptedTestingSource interruptedTestingSource = new WasInterruptedTestingSource();
-        WasInterruptedTestingSource.reset();
-        try (StreamTaskMailboxTestHarness<String> harness =
-                new StreamTaskMailboxTestHarnessBuilder<>(SourceStreamTask::new, STRING_TYPE_INFO)
-                        .setupOutputForSingletonOperatorChain(
-                                new StreamSource<>(interruptedTestingSource))
-                        .build()) {
-
-            harness.processAll();
-
-            Future<Boolean> triggerFuture =
-                    harness.streamTask.triggerCheckpointAsync(
-                            new CheckpointMetaData(checkpointId, 1),
-                            new CheckpointOptions(SAVEPOINT_SUSPEND, getDefault()));
-            while (!triggerFuture.isDone()) {
-                harness.streamTask.runMailboxStep();
-            }
-            triggerFuture.get();
-
-            Future<Void> notifyFuture =
-                    harness.streamTask.notifyCheckpointCompleteAsync(checkpointId);
-            while (!notifyFuture.isDone()) {
-                harness.streamTask.runMailboxStep();
-            }
-            notifyFuture.get();
-
-            WasInterruptedTestingSource.allowExit();
-
-            harness.waitForTaskCompletion();
-            harness.finishProcessing();
-
-            assertTrue(notifyFuture.isDone());
-            assertFalse(interruptedTestingSource.wasInterrupted());
-        }
-    }
-
-    @Test
     public void testTriggeringCheckpointAfterSourceThreadFinished() throws Exception {
         ResultPartition[] partitionWriters = new ResultPartition[2];
         try (NettyShuffleEnvironment env =
@@ -720,9 +622,8 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData(boolean shouldDrain)
-                                            throws IOException {
-                                        broadcastEvent(new EndOfData(shouldDrain), false);
+                                    public void notifyEndOfData(StopMode mode) throws IOException {
+                                        broadcastEvent(new EndOfData(mode), false);
                                     }
                                 })
                         .setupOperatorChain(new StreamSource<>(testSource))
@@ -733,7 +634,7 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
             harness.processAll();
             harness.streamTask.getCompletionFuture().get();
 
-            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(StopMode.DRAIN)));
 
             LifeCycleMonitorSource source =
                     (LifeCycleMonitorSource)
@@ -1036,47 +937,6 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
         }
     }
 
-    /**
-     * A {@link SourceFunction} that throws an exception from {@link #run(SourceContext)} when it is
-     * cancelled via {@link #cancel()}.
-     */
-    private static class ExceptionThrowingSource implements SourceFunction<String> {
-
-        private static volatile CompletableFuture<Void> isInRunLoop;
-
-        private volatile boolean running = true;
-
-        public static class TestException extends RuntimeException {
-            public TestException(String message) {
-                super(message);
-            }
-        }
-
-        public static void setIsInRunLoopFuture(
-                @Nonnull final CompletableFuture<Void> waitingLatch) {
-            ExceptionThrowingSource.isInRunLoop = waitingLatch;
-        }
-
-        @Override
-        public void run(SourceContext<String> ctx) throws TestException {
-            checkState(isInRunLoop != null && !isInRunLoop.isDone());
-
-            while (running) {
-                if (!isInRunLoop.isDone()) {
-                    isInRunLoop.complete(null);
-                }
-                ctx.collect("hello");
-            }
-
-            throw new TestException("Oh no, we're failing.");
-        }
-
-        @Override
-        public void cancel() {
-            running = false;
-        }
-    }
-
     private static final class OutputRecordInCloseTestSource<SRC extends SourceFunction<String>>
             extends StreamSource<String, SRC> implements BoundedOneInput {
 
@@ -1108,48 +968,6 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
         }
     }
 
-    /**
-     * This source sleeps a little bit before processing cancellation and records whether it was
-     * interrupted by the {@link SourceStreamTask} or not.
-     */
-    private static class WasInterruptedTestingSource implements SourceFunction<String> {
-        private static final long serialVersionUID = 1L;
-
-        private static final OneShotLatch ALLOW_EXIT = new OneShotLatch();
-        private static final AtomicBoolean WAS_INTERRUPTED = new AtomicBoolean();
-
-        private volatile boolean running = true;
-
-        @Override
-        public void run(SourceContext<String> ctx) throws Exception {
-            try {
-                while (running || !ALLOW_EXIT.isTriggered()) {
-                    Thread.sleep(1);
-                }
-            } catch (InterruptedException e) {
-                WAS_INTERRUPTED.set(true);
-            }
-        }
-
-        @Override
-        public void cancel() {
-            running = false;
-        }
-
-        public static boolean wasInterrupted() {
-            return WAS_INTERRUPTED.get();
-        }
-
-        public static void reset() {
-            ALLOW_EXIT.reset();
-            WAS_INTERRUPTED.set(false);
-        }
-
-        public static void allowExit() {
-            ALLOW_EXIT.trigger();
-        }
-    }
-
     private static class EmptySource implements SourceFunction<String> {
 
         private volatile boolean isCanceled;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
index 5a67e86..c963f3e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointType;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.operators.StreamSource;
@@ -40,7 +41,6 @@ import org.junit.Test;
 import java.util.Queue;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 /**
@@ -107,9 +107,11 @@ public class SourceTaskTerminationTest extends TestLogger {
                 // if we are in TERMINATE mode, we expect the source task
                 // to emit MAX_WM before the SYNC_SAVEPOINT barrier.
                 verifyWatermark(srcTaskTestHarness.getOutput(), Watermark.MAX_WATERMARK);
-                verifyEvent(srcTaskTestHarness.getOutput(), new EndOfData(true));
             }
 
+            verifyEvent(
+                    srcTaskTestHarness.getOutput(),
+                    new EndOfData(shouldTerminate ? StopMode.DRAIN : StopMode.NO_DRAIN));
             verifyCheckpointBarrier(srcTaskTestHarness.getOutput(), syncSavepointId);
 
             waitForSynchronousSavepointIdToBeSet(srcTask);
@@ -118,9 +120,6 @@ public class SourceTaskTerminationTest extends TestLogger {
 
             srcTaskTestHarness.processUntil(
                     srcTask.notifyCheckpointCompleteAsync(syncSavepointId)::isDone);
-            if (!shouldTerminate) {
-                assertFalse(srcTask.getSynchronousSavepointId().isPresent());
-            }
 
             srcTaskTestHarness.waitForTaskCompletion();
         }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapperTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapperTest.java
index 0c39909..b255847 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapperTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamOperatorWrapperTest.java
@@ -20,6 +20,7 @@ package org.apache.flink.streaming.runtime.tasks;
 
 import org.apache.flink.api.common.operators.MailboxExecutor;
 import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
@@ -133,7 +134,7 @@ public class StreamOperatorWrapperTest extends TestLogger {
     @Test
     public void testFinish() throws Exception {
         output.clear();
-        operatorWrappers.get(0).finish(containingTask.getActionExecutor());
+        operatorWrappers.get(0).finish(containingTask.getActionExecutor(), StopMode.DRAIN);
 
         List<Object> expected = new ArrayList<>();
         for (int i = 0; i < operatorWrappers.size(); i++) {
@@ -172,7 +173,7 @@ public class StreamOperatorWrapperTest extends TestLogger {
                         true);
 
         try {
-            operatorWrapper.finish(containingTask.getActionExecutor());
+            operatorWrapper.finish(containingTask.getActionExecutor(), StopMode.DRAIN);
             fail("should throw an exception");
         } catch (Throwable t) {
             Optional<Throwable> optional =
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
index 88087d5..77df7e7 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
@@ -23,7 +23,6 @@ import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
-import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
@@ -34,6 +33,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.PartitionTestUtils;
 import org.apache.flink.runtime.io.network.partition.PipelinedResultPartition;
@@ -62,7 +62,6 @@ import java.util.concurrent.Future;
 
 import static org.apache.flink.api.common.typeinfo.BasicTypeInfo.STRING_TYPE_INFO;
 import static org.apache.flink.runtime.state.CheckpointStorageLocationReference.getDefault;
-import static org.apache.flink.util.ExceptionUtils.assertThrowable;
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.contains;
 import static org.junit.Assert.assertArrayEquals;
@@ -86,7 +85,8 @@ public class StreamTaskFinalCheckpointsTest {
         harness.setAutoProcess(false);
         harness.processElement(new StreamRecord<>(1));
 
-        harness.streamTask.operatorChain.finishOperators(harness.streamTask.getActionExecutor());
+        harness.streamTask.operatorChain.finishOperators(
+                harness.streamTask.getActionExecutor(), StopMode.DRAIN);
         assertTrue(FinishingOperator.finished);
 
         harness.getTaskStateManager().getWaitForReportLatch().reset();
@@ -155,7 +155,7 @@ public class StreamTaskFinalCheckpointsTest {
                 assertEquals(2, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(new EndOfData(true), 0, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
@@ -163,8 +163,8 @@ public class StreamTaskFinalCheckpointsTest {
 
                 // Tests triggering checkpoint after received all the inputs have received
                 // EndOfPartition.
-                testHarness.processEvent(new EndOfData(true), 0, 1);
-                testHarness.processEvent(new EndOfData(true), 0, 2);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 1);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 2);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
                 checkpointFuture = triggerCheckpoint(testHarness, lastCheckpointId);
@@ -400,6 +400,17 @@ public class StreamTaskFinalCheckpointsTest {
     @Test
     public void testTriggerStopWithSavepointWhenWaitingForFinalCheckpointOnSourceTask()
             throws Exception {
+        doTestTriggerStopWithSavepointWhenWaitingForFinalCheckpointOnSourceTask(true);
+    }
+
+    @Test
+    public void testTriggerStopWithSavepointNoDrainWhenWaitingForFinalCheckpointOnSourceTask()
+            throws Exception {
+        doTestTriggerStopWithSavepointWhenWaitingForFinalCheckpointOnSourceTask(false);
+    }
+
+    private void doTestTriggerStopWithSavepointWhenWaitingForFinalCheckpointOnSourceTask(
+            boolean drain) throws Exception {
         int finalCheckpointId = 6;
         int syncSavepointId = 7;
         CompletingCheckpointResponder checkpointResponder =
@@ -468,7 +479,9 @@ public class StreamTaskFinalCheckpointsTest {
 
             // trigger the synchronous savepoint
             CompletableFuture<Boolean> savepointFuture =
-                    triggerStopWithSavepointDrain(testHarness, syncSavepointId);
+                    drain
+                            ? triggerStopWithSavepointDrain(testHarness, syncSavepointId)
+                            : triggerStopWithSavepointNoDrain(testHarness, syncSavepointId);
 
             // The checkpoint 6 would be triggered successfully.
             testHarness.finishProcessing();
@@ -484,130 +497,6 @@ public class StreamTaskFinalCheckpointsTest {
     }
 
     @Test
-    public void testTriggerStopWithSavepointNoDrainWhenWaitingForFinalCheckpointOnSourceTask()
-            throws Exception {
-        int finalCheckpointId = 6;
-        int syncSavepointId = 7;
-        CompletingCheckpointResponder checkpointResponder =
-                new CompletingCheckpointResponder() {
-
-                    private CheckpointMetrics metrics;
-                    private TaskStateSnapshot stateSnapshot;
-
-                    @Override
-                    public void acknowledgeCheckpoint(
-                            JobID jobID,
-                            ExecutionAttemptID executionAttemptID,
-                            long checkpointId,
-                            CheckpointMetrics checkpointMetrics,
-                            TaskStateSnapshot subtaskState) {
-                        // do not acknowledge any checkpoints straightaway
-                        if (checkpointId == finalCheckpointId) {
-                            metrics = checkpointMetrics;
-                            stateSnapshot = subtaskState;
-                        }
-                    }
-
-                    @Override
-                    public void declineCheckpoint(
-                            JobID jobID,
-                            ExecutionAttemptID executionAttemptID,
-                            long checkpointId,
-                            CheckpointException checkpointException) {
-                        // acknowledge the last pending checkpoint once the synchronous savepoint is
-                        // declined.
-                        if (syncSavepointId == checkpointId) {
-                            super.acknowledgeCheckpoint(
-                                    jobID,
-                                    executionAttemptID,
-                                    finalCheckpointId,
-                                    metrics,
-                                    stateSnapshot);
-                        }
-                    }
-                };
-
-        try (StreamTaskMailboxTestHarness<String> testHarness =
-                new StreamTaskMailboxTestHarnessBuilder<>(SourceStreamTask::new, STRING_TYPE_INFO)
-                        .modifyStreamConfig(
-                                config -> {
-                                    config.setCheckpointingEnabled(true);
-                                    config.getConfiguration()
-                                            .set(
-                                                    ExecutionCheckpointingOptions
-                                                            .ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
-                                                    true);
-                                })
-                        .setCheckpointResponder(checkpointResponder)
-                        .setupOutputForSingletonOperatorChain(
-                                new StreamSource<>(new ImmediatelyFinishingSource()))
-                        .build()) {
-            checkpointResponder.setHandlers(
-                    testHarness.streamTask::notifyCheckpointCompleteAsync,
-                    testHarness.streamTask::notifyCheckpointAbortAsync);
-
-            // start task thread
-            testHarness.streamTask.runMailboxLoop();
-
-            // trigger the final checkpoint
-            CompletableFuture<Boolean> checkpointFuture =
-                    triggerCheckpoint(testHarness, finalCheckpointId);
-
-            // trigger the synchronous savepoint w/o drain, which should be declined
-            CompletableFuture<Boolean> savepointFuture =
-                    triggerStopWithSavepointNoDrain(testHarness, syncSavepointId);
-
-            // The checkpoint 6 would be triggered successfully.
-            testHarness.finishProcessing();
-            assertTrue(checkpointFuture.isDone());
-            assertTrue(savepointFuture.isDone());
-            assertFalse(savepointFuture.get());
-            testHarness.getTaskStateManager().getWaitForReportLatch().await();
-            assertEquals(
-                    finalCheckpointId, testHarness.getTaskStateManager().getReportedCheckpointId());
-            assertEquals(
-                    finalCheckpointId,
-                    testHarness.getTaskStateManager().getNotifiedCompletedCheckpointId());
-        }
-    }
-
-    @Test
-    public void testTriggerSourceFinishesWhileStoppingWithSavepointWithoutDrain() throws Exception {
-        try (StreamTaskMailboxTestHarness<String> testHarness =
-                new StreamTaskMailboxTestHarnessBuilder<>(SourceStreamTask::new, STRING_TYPE_INFO)
-                        .modifyStreamConfig(
-                                config -> {
-                                    config.setCheckpointingEnabled(true);
-                                    config.getConfiguration()
-                                            .set(
-                                                    ExecutionCheckpointingOptions
-                                                            .ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
-                                                    true);
-                                })
-                        .setupOutputForSingletonOperatorChain(
-                                new StreamSource<>(new ImmediatelyFinishingSource()))
-                        .build()) {
-
-            // trigger the synchronous savepoint w/o drain
-            triggerStopWithSavepointNoDrain(testHarness, 1);
-
-            // start task thread
-            testHarness.streamTask.runMailboxLoop();
-        } catch (Exception ex) {
-            assertThrowable(
-                    ex,
-                    (e ->
-                            e.getMessage()
-                                    .equals(
-                                            "We run out of data to process while waiting for a "
-                                                    + "synchronous savepoint to be finished. This "
-                                                    + "can lead to a deadlock waiting for a final "
-                                                    + "checkpoint after a synchronous savepoint, "
-                                                    + "which will never be triggered.")));
-        }
-    }
-
-    @Test
     public void testTriggeringAlignedNoTimeoutCheckpointWithFinishedChannels() throws Exception {
         testTriggeringCheckpointWithFinishedChannels(
                 CheckpointOptions.alignedNoTimeout(
@@ -664,7 +553,7 @@ public class StreamTaskFinalCheckpointsTest {
                 assertArrayEquals(new int[] {0, 0, 0}, resumedCount);
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(new EndOfData(true), 0, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4, checkpointOptions);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
@@ -673,8 +562,8 @@ public class StreamTaskFinalCheckpointsTest {
 
                 // Tests triggering checkpoint after received all the inputs have received
                 // EndOfPartition.
-                testHarness.processEvent(new EndOfData(true), 0, 1);
-                testHarness.processEvent(new EndOfData(true), 0, 2);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 1);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 2);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
                 checkpointFuture = triggerCheckpoint(testHarness, 6, checkpointOptions);
@@ -759,7 +648,7 @@ public class StreamTaskFinalCheckpointsTest {
                 // The checkpoint is added to the mailbox and will be processed in the
                 // mailbox loop after call operators' finish method in the afterInvoke()
                 // method.
-                testHarness.processEvent(new EndOfData(true), 0, 0);
+                testHarness.processEvent(new EndOfData(StopMode.DRAIN), 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4);
                 checkpointFuture.thenAccept(
                         (ignored) -> {
@@ -947,7 +836,7 @@ public class StreamTaskFinalCheckpointsTest {
                                     checkpointMetaData.getTimestamp(),
                                     checkpointOptions),
                             Watermark.MAX_WATERMARK,
-                            new EndOfData(true)));
+                            new EndOfData(StopMode.DRAIN)));
         }
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 3683b2f..9de129c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -46,6 +46,7 @@ import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
@@ -232,6 +233,8 @@ public class StreamTaskTest extends TestLogger {
 
     @Test
     public void testSavepointSuspendedAborted() throws Exception {
+        thrown.expect(FlinkRuntimeException.class);
+        thrown.expectMessage("Stop-with-savepoint failed.");
         testSyncSavepointWithEndInput(
                 (task, id) ->
                         task.abortCheckpointOnBarrier(
@@ -245,7 +248,7 @@ public class StreamTaskTest extends TestLogger {
     @Test
     public void testSavepointTerminateAborted() throws Exception {
         thrown.expect(FlinkRuntimeException.class);
-        thrown.expectMessage("Stop-with-savepoint --drain failed.");
+        thrown.expectMessage("Stop-with-savepoint failed.");
         testSyncSavepointWithEndInput(
                 (task, id) ->
                         task.abortCheckpointOnBarrier(
@@ -258,6 +261,8 @@ public class StreamTaskTest extends TestLogger {
 
     @Test
     public void testSavepointSuspendAbortedAsync() throws Exception {
+        thrown.expect(FlinkRuntimeException.class);
+        thrown.expectMessage("Stop-with-savepoint failed.");
         testSyncSavepointWithEndInput(
                 (streamTask, abortCheckpointId) ->
                         streamTask.notifyCheckpointAbortAsync(abortCheckpointId, 0),
@@ -268,7 +273,7 @@ public class StreamTaskTest extends TestLogger {
     @Test
     public void testSavepointTerminateAbortedAsync() throws Exception {
         thrown.expect(FlinkRuntimeException.class);
-        thrown.expectMessage("Stop-with-savepoint --drain failed.");
+        thrown.expectMessage("Stop-with-savepoint failed.");
         testSyncSavepointWithEndInput(
                 (streamTask, abortCheckpointId) ->
                         streamTask.notifyCheckpointAbortAsync(abortCheckpointId, 0),
@@ -990,7 +995,8 @@ public class StreamTaskTest extends TestLogger {
         assertFalse(ClosingOperator.closed.get());
 
         // close operators directly, so that task is still fully running
-        harness.streamTask.operatorChain.finishOperators(harness.streamTask.getActionExecutor());
+        harness.streamTask.operatorChain.finishOperators(
+                harness.streamTask.getActionExecutor(), StopMode.DRAIN);
         harness.streamTask.operatorChain.closeAllOperators();
         harness.streamTask.notifyCheckpointCompleteAsync(2);
         harness.streamTask.runMailboxStep();
@@ -1052,7 +1058,7 @@ public class StreamTaskTest extends TestLogger {
             harness.processElement(new StreamRecord<>(1));
 
             harness.streamTask.operatorChain.finishOperators(
-                    harness.streamTask.getActionExecutor());
+                    harness.streamTask.getActionExecutor(), StopMode.DRAIN);
             harness.streamTask.operatorChain.closeAllOperators();
             assertTrue(ClosingOperator.closed.get());
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index a95b17c..ff415f3 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -34,6 +34,7 @@ import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
@@ -492,7 +493,7 @@ public class StreamTaskTestHarness<OUT> {
 
     public void endInput(int gateIndex, int channelIndex, boolean emitEndOfData) {
         if (emitEndOfData) {
-            inputGates[gateIndex].sendEvent(new EndOfData(true), channelIndex);
+            inputGates[gateIndex].sendEvent(new EndOfData(StopMode.DRAIN), channelIndex);
         }
         inputGates[gateIndex].sendEvent(EndOfPartitionEvent.INSTANCE, channelIndex);
     }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointTest.java
index 3194dfa..100b89b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SynchronousCheckpointTest.java
@@ -39,7 +39,6 @@ import java.util.concurrent.LinkedBlockingQueue;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
@@ -76,24 +75,6 @@ public class SynchronousCheckpointTest {
         assertThat(eventQueue.take(), is(Event.TASK_INITIALIZED));
     }
 
-    @Test(timeout = 20_000)
-    public void synchronousCheckpointBlocksUntilNotificationForCorrectCheckpointComes()
-            throws Exception {
-        launchSynchronousSavepointAndWaitForSyncSavepointIdToBeSet();
-        assertTrue(streamTaskUnderTest.getSynchronousSavepointId().isPresent());
-
-        streamTaskUnderTest.notifyCheckpointCompleteAsync(41).get();
-        assertTrue(streamTaskUnderTest.getSynchronousSavepointId().isPresent());
-
-        streamTaskUnderTest.notifyCheckpointCompleteAsync(42).get();
-        assertFalse(streamTaskUnderTest.getSynchronousSavepointId().isPresent());
-
-        streamTaskUnderTest.stopTask();
-        waitUntilMainExecutionThreadIsFinished();
-
-        assertFalse(streamTaskUnderTest.isCanceled());
-    }
-
     @Test(timeout = 10_000)
     public void cancelShouldAlsoCancelPendingSynchronousCheckpoint() throws Throwable {
         launchSynchronousSavepointAndWaitForSyncSavepointIdToBeSet();
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index ab8304f..e32d67c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.network.api.CancelCheckpointMarker;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfData;
+import org.apache.flink.runtime.io.network.api.StopMode;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.metrics.MetricNames;
@@ -543,7 +544,7 @@ public class TwoInputStreamTaskTest {
             testHarness.waitForTaskCompletion();
             assertThat(
                     testHarness.getOutput(),
-                    contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
+                    contains(Watermark.MAX_WATERMARK, new EndOfData(StopMode.DRAIN)));
         }
     }
 
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java
index 2164968..1ecdce8 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/jobmaster/JobMasterStopWithSavepointITCase.java
@@ -54,10 +54,8 @@ import org.junit.rules.TemporaryFolder;
 import java.io.IOException;
 import java.nio.file.Files;
 import java.nio.file.Path;
-import java.nio.file.Paths;
 import java.time.Duration;
 import java.util.Collections;
-import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
@@ -65,13 +63,10 @@ import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Future;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.either;
 import static org.hamcrest.Matchers.equalTo;
-import static org.hamcrest.Matchers.hasItem;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
@@ -100,31 +95,6 @@ public class JobMasterStopWithSavepointITCase extends AbstractTestBase {
     private JobGraph jobGraph;
 
     @Test
-    public void suspendWithSavepointWithoutComplicationsShouldSucceedAndLeadJobToFinished()
-            throws Exception {
-        stopWithSavepointNormalExecutionHelper(false);
-    }
-
-    private void stopWithSavepointNormalExecutionHelper(final boolean terminate) throws Exception {
-        setUpJobGraph(NoOpBlockingStreamTask.class, RestartStrategies.noRestart());
-
-        final CompletableFuture<String> savepointLocationFuture = stopWithSavepoint(terminate);
-
-        assertThat(getJobStatus(), equalTo(JobStatus.RUNNING));
-
-        finishingLatch.trigger();
-
-        final String savepointLocation = savepointLocationFuture.get();
-        assertThat(getJobStatus(), equalTo(JobStatus.FINISHED));
-
-        final List<Path> savepoints;
-        try (Stream<Path> savepointFiles = Files.list(savepointDirectory)) {
-            savepoints = savepointFiles.map(Path::getFileName).collect(Collectors.toList());
-        }
-        assertThat(savepoints, hasItem(Paths.get(savepointLocation).getFileName()));
-    }
-
-    @Test
     public void throwingExceptionOnCallbackWithNoRestartsShouldFailTheSuspend() throws Exception {
         throwingExceptionOnCallbackWithoutRestartsHelper(false);
     }
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/BoundedSourceITCase.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/BoundedSourceITCase.java
index ffeb08d..d913799 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/BoundedSourceITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/BoundedSourceITCase.java
@@ -76,6 +76,6 @@ public class BoundedSourceITCase extends AbstractTestBase {
                 .assertFinishedSuccessfully();
 
         checkOperatorsLifecycle(testJob, new DrainingValidator(), new FinishingValidator());
-        checkDataFlow(testJob);
+        checkDataFlow(testJob, true);
     }
 }
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/PartiallyFinishedSourcesITCase.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/PartiallyFinishedSourcesITCase.java
index ef4ac66..7a29238 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/PartiallyFinishedSourcesITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/PartiallyFinishedSourcesITCase.java
@@ -136,7 +136,7 @@ public class PartiallyFinishedSourcesITCase extends TestLogger {
                 .assertFinishedSuccessfully();
 
         checkOperatorsLifecycle(testJob, new DrainingValidator(), new FinishingValidator());
-        checkDataFlow(testJob);
+        checkDataFlow(testJob, true);
     }
 
     private TestJobWithDescription buildJob() throws Exception {
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/StopWithSavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/StopWithSavepointITCase.java
index a787eed..c3befd6 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/StopWithSavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/StopWithSavepointITCase.java
@@ -61,8 +61,7 @@ import static org.apache.flink.runtime.operators.lifecycle.validation.TestOperat
  *         <li>{@link Watermark#MAX_WATERMARK MAX_WATERMARK} (if with drain)
  *         <li>{@link BoundedMultiInput#endInput endInput} (if with drain)
  *         <li>timer service quiesced
- *         <li>{@link StreamOperator#finish() finish} (if with drain; support is planned for
- *             no-drain)
+ *         <li>{@link StreamOperator#finish() finish} (if with drain)
  *         <li>{@link AbstractStreamOperator#snapshotState(StateSnapshotContext) snapshotState} (for
  *             the respective checkpoint)
  *         <li>{@link CheckpointListener#notifyCheckpointComplete notifyCheckpointComplete} (for the
@@ -131,17 +130,13 @@ public class StopWithSavepointITCase extends AbstractTestBase {
                     testJob,
                     sameCheckpointValidator,
                     new DrainingValidator(),
-                    /* Currently (1.14), finish is only called with drain; todo: enable after updating production code */
+                    /* Currently (1.14), finish is only called with drain */
                     new FinishingValidator());
         } else {
             checkOperatorsLifecycle(testJob, sameCheckpointValidator);
         }
 
-        if (withDrain) {
-            // currently (1.14), sources do not stop before taking a savepoint and continue emission
-            // todo: enable after updating production code
-            checkDataFlow(testJob);
-        }
+        checkDataFlow(testJob, withDrain);
     }
 
     @Parameterized.Parameters(name = "withDrain: {0}, {1}")
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/graph/TestEventSource.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/graph/TestEventSource.java
index 1badb6a..4259663 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/graph/TestEventSource.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/graph/TestEventSource.java
@@ -20,8 +20,6 @@ package org.apache.flink.runtime.operators.lifecycle.graph;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.operators.lifecycle.command.TestCommand;
 import org.apache.flink.runtime.operators.lifecycle.command.TestCommandDispatcher;
-import org.apache.flink.runtime.operators.lifecycle.event.OperatorFinishedEvent;
-import org.apache.flink.runtime.operators.lifecycle.event.OperatorFinishedEvent.LastReceivedVertexDataInfo;
 import org.apache.flink.runtime.operators.lifecycle.event.OperatorStartedEvent;
 import org.apache.flink.runtime.operators.lifecycle.event.TestCommandAckEvent;
 import org.apache.flink.runtime.operators.lifecycle.event.TestEvent;
@@ -33,7 +31,6 @@ import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import java.util.Queue;
 import java.util.concurrent.LinkedBlockingQueue;
 
-import static java.util.Collections.emptyMap;
 import static org.apache.flink.runtime.operators.lifecycle.command.TestCommand.FAIL;
 import static org.apache.flink.runtime.operators.lifecycle.command.TestCommand.FINISH_SOURCES;
 
@@ -93,16 +90,6 @@ class TestEventSource extends RichSourceFunction<TestDataElement>
                 throw new RuntimeException("unknown command " + cmd);
             }
         }
-        // note: this only gets collected with FLIP-147 changes
-        synchronized (ctx.getCheckpointLock()) {
-            eventQueue.add(
-                    new OperatorFinishedEvent(
-                            operatorID,
-                            getRuntimeContext().getIndexOfThisSubtask(),
-                            getRuntimeContext().getAttemptNumber(),
-                            lastSent,
-                            new LastReceivedVertexDataInfo(emptyMap())));
-        }
     }
 
     private void ack(TestCommand cmd) {
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
index 7a8f789..137e504 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
@@ -36,13 +36,14 @@ import java.util.Optional;
 
 import static java.lang.String.format;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 
 /** Check that all data from the upstream reached the respective downstreams. */
 public class TestJobDataFlowValidator {
     private static final Logger LOG = LoggerFactory.getLogger(TestJobDataFlowValidator.class);
 
-    public static void checkDataFlow(TestJobWithDescription testJob) {
+    public static void checkDataFlow(TestJobWithDescription testJob, boolean withDrain) {
         Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents = new HashMap<>();
         for (TestEvent ev : testJob.eventQueue.getAll()) {
             if (ev instanceof OperatorFinishedEvent) {
@@ -55,11 +56,22 @@ public class TestJobDataFlowValidator {
         for (JobVertex upstream : testJob.jobGraph.getVertices()) {
             for (IntermediateDataSet produced : upstream.getProducedDataSets()) {
                 for (JobEdge edge : produced.getConsumers()) {
-                    Optional<String> upstreamID = getTrackedOperatorID(upstream, true, testJob);
-                    Optional<String> downstreamID =
+                    Optional<String> upstreamIDOptional =
+                            getTrackedOperatorID(upstream, true, testJob);
+                    Optional<String> downstreamIDOptional =
                             getTrackedOperatorID(edge.getTarget(), false, testJob);
-                    if (upstreamID.isPresent() && downstreamID.isPresent()) {
-                        checkDataFlow(upstreamID.get(), downstreamID.get(), edge, finishEvents);
+                    if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
+                        final String upstreamID = upstreamIDOptional.get();
+                        final String downstreamID = downstreamIDOptional.get();
+                        if (testJob.sources.contains(upstreamID)) {
+                            // TODO: if we add tests for FLIP-27 sources we might need to adjust
+                            // this condition
+                            LOG.debug(
+                                    "Legacy sources do not have the finish() method and thus do not"
+                                            + " emit FinishEvent");
+                        } else {
+                            checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+                        }
                     } else {
                         LOG.debug("Ignoring edge (untracked operator): {}", edge);
                     }
@@ -76,7 +88,8 @@ public class TestJobDataFlowValidator {
             String upstreamID,
             String downstreamID,
             JobEdge edge,
-            Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents) {
+            Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents,
+            boolean withDrain) {
         LOG.debug(
                 "Checking {} edge\n  from {} ({})\n  to {} ({})",
                 edge.getDistributionPattern(),
@@ -86,25 +99,27 @@ public class TestJobDataFlowValidator {
                 downstreamID);
 
         Map<Integer, OperatorFinishedEvent> downstreamFinishInfo =
-                getForOperator(downstreamID, finishEvents);
+                getForOperator(downstreamID, finishEvents, withDrain);
 
         Map<Integer, OperatorFinishedEvent> upstreamFinishInfo =
-                getForOperator(upstreamID, finishEvents);
+                getForOperator(upstreamID, finishEvents, withDrain);
 
-        upstreamFinishInfo.forEach(
-                (upstreamIndex, upstreamInfo) ->
-                        assertTrue(
-                                String.format(
-                                        "No downstream received %s from %s[%d]; received: %s",
-                                        upstreamInfo.lastSent,
-                                        upstreamID,
-                                        upstreamIndex,
-                                        downstreamFinishInfo),
-                                anySubtaskReceived(
-                                        upstreamID,
-                                        upstreamIndex,
-                                        upstreamInfo.lastSent,
-                                        downstreamFinishInfo.values())));
+        if (withDrain) {
+            upstreamFinishInfo.forEach(
+                    (upstreamIndex, upstreamInfo) ->
+                            assertTrue(
+                                    String.format(
+                                            "No downstream received %s from %s[%d]; received: %s",
+                                            upstreamInfo.lastSent,
+                                            upstreamID,
+                                            upstreamIndex,
+                                            downstreamFinishInfo),
+                                    anySubtaskReceived(
+                                            upstreamID,
+                                            upstreamIndex,
+                                            upstreamInfo.lastSent,
+                                            downstreamFinishInfo.values())));
+        }
     }
 
     private static boolean anySubtaskReceived(
@@ -118,13 +133,23 @@ public class TestJobDataFlowValidator {
     }
 
     private static Map<Integer, OperatorFinishedEvent> getForOperator(
-            String operatorId, Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents) {
+            String operatorId,
+            Map<String, Map<Integer, OperatorFinishedEvent>> finishEvents,
+            boolean withDrain) {
         Map<Integer, OperatorFinishedEvent> events = finishEvents.get(operatorId);
-        assertNotNull(
-                format(
-                        "Operator finish info wasn't collected: %s (collected: %s)",
-                        operatorId, finishEvents),
-                events);
+        if (withDrain) {
+            assertNotNull(
+                    format(
+                            "Operator finish info wasn't collected with draining: %s (collected: %s)",
+                            operatorId, finishEvents),
+                    events);
+        } else {
+            assertNull(
+                    format(
+                            "Operator finish info was collected without draining: %s (collected: %s)",
+                            operatorId, finishEvents),
+                    events);
+        }
         return events;
     }
 
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index b8348f9..b791873 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -24,6 +24,7 @@ import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.CheckpointListener;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.common.time.Deadline;
@@ -733,7 +734,7 @@ public class SavepointITCase extends TestLogger {
     @Test
     public void testStopWithSavepointFailingAfterSnapshotCreation() throws Exception {
         // the trigger need to be reset in case the test is run multiple times
-        CancelFailingInfiniteTestSource.cancelTriggered = false;
+        CancelFailingInfiniteTestSource.checkpointCompleteTriggered = false;
         testStopWithFailingSourceInOnePipeline(
                 new CancelFailingInfiniteTestSource(),
                 folder.newFolder(),
@@ -1190,14 +1191,15 @@ public class SavepointITCase extends TestLogger {
      * An {@link InfiniteTestSource} implementation that fails when cancel is called for the first
      * time.
      */
-    private static class CancelFailingInfiniteTestSource extends InfiniteTestSource {
+    private static class CancelFailingInfiniteTestSource extends InfiniteTestSource
+            implements CheckpointListener {
 
-        private static volatile boolean cancelTriggered = false;
+        private static volatile boolean checkpointCompleteTriggered = false;
 
         @Override
-        public void cancel() {
-            if (!cancelTriggered) {
-                cancelTriggered = true;
+        public void notifyCheckpointComplete(long checkpointId) throws Exception {
+            if (!checkpointCompleteTriggered) {
+                checkpointCompleteTriggered = true;
                 throw new RuntimeException("Expected RuntimeException after snapshot creation.");
             }
             super.cancel();
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveSchedulerITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveSchedulerITCase.java
index 655298f..13062e7 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveSchedulerITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/AdaptiveSchedulerITCase.java
@@ -124,7 +124,7 @@ public class AdaptiveSchedulerITCase extends TestLogger {
     private enum StopWithSavepointTestBehavior {
         NO_FAILURE,
         FAIL_ON_CHECKPOINT,
-        FAIL_ON_STOP,
+        FAIL_ON_CHECKPOINT_COMPLETE,
         FAIL_ON_FIRST_CHECKPOINT_ONLY
     }
 
@@ -172,7 +172,7 @@ public class AdaptiveSchedulerITCase extends TestLogger {
     @Test
     public void testStopWithSavepointFailOnStop() throws Exception {
         StreamExecutionEnvironment env =
-                getEnvWithSource(StopWithSavepointTestBehavior.FAIL_ON_STOP);
+                getEnvWithSource(StopWithSavepointTestBehavior.FAIL_ON_CHECKPOINT_COMPLETE);
         env.setRestartStrategy(RestartStrategies.fixedDelayRestart(Integer.MAX_VALUE, 0L));
 
         DummySource.resetForParallelism(PARALLELISM);
@@ -257,7 +257,6 @@ public class AdaptiveSchedulerITCase extends TestLogger {
         private final StopWithSavepointTestBehavior behavior;
         private volatile boolean running = true;
         private static volatile CountDownLatch instancesRunning;
-        private volatile boolean checkpointComplete = false;
 
         public DummySource(StopWithSavepointTestBehavior behavior) {
             this.behavior = behavior;
@@ -288,9 +287,6 @@ public class AdaptiveSchedulerITCase extends TestLogger {
         @Override
         public void cancel() {
             running = false;
-            if (checkpointComplete && behavior == StopWithSavepointTestBehavior.FAIL_ON_STOP) {
-                throw new RuntimeException(behavior.name());
-            }
         }
 
         @Override
@@ -309,7 +305,9 @@ public class AdaptiveSchedulerITCase extends TestLogger {
 
         @Override
         public void notifyCheckpointComplete(long checkpointId) throws Exception {
-            checkpointComplete = true;
+            if (behavior == StopWithSavepointTestBehavior.FAIL_ON_CHECKPOINT_COMPLETE) {
+                throw new RuntimeException(behavior.name());
+            }
         }
     }