You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by dw...@apache.org on 2021/12/08 07:18:16 UTC

[flink] 01/04: [FLINK-23532] Pass a flag for draining along with EndOfData

This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 72a2471fdb08a625cbe173ef89b53db8425a14b6
Author: Dawid Wysakowicz <dw...@apache.org>
AuthorDate: Tue Nov 30 14:13:25 2021 +0100

    [FLINK-23532] Pass a flag for draining along with EndOfData
    
    For the sake of unification, we want to emit EndOfData on stop-with-savepoint both with and without drain. This commit is a preparation step so that the drain flag can be carried inside EndOfData.
---
 .../flink/runtime/io/PullingAsyncDataInput.java    | 10 +++++
 .../flink/runtime/io/network/api/EndOfData.java    | 44 +++++++++++++++------
 .../network/api/serialization/EventSerializer.java | 11 +++++-
 .../network/api/writer/ResultPartitionWriter.java  |  5 ++-
 .../partition/BoundedBlockingResultPartition.java  |  4 +-
 .../partition/PipelinedResultPartition.java        |  4 +-
 .../io/network/partition/ResultPartition.java      |  2 +-
 .../partition/SortMergeResultPartition.java        |  4 +-
 .../io/network/partition/consumer/InputGate.java   | 10 +++++
 .../partition/consumer/SingleInputGate.java        |  7 ++++
 .../network/partition/consumer/UnionInputGate.java |  7 ++++
 ...bleNotifyingResultPartitionWriterDecorator.java |  4 +-
 .../runtime/taskmanager/InputGateWithMetrics.java  |  5 +++
 .../api/serialization/EventSerializerTest.java     |  3 +-
 ...cordOrEventCollectingResultPartitionWriter.java |  4 +-
 .../netty/PartitionRequestServerHandlerTest.java   |  2 +-
 .../BoundedBlockingSubpartitionWriteReadTest.java  |  2 +-
 .../partition/MockResultPartitionWriter.java       |  2 +-
 .../io/network/partition/ResultPartitionTest.java  |  4 +-
 .../partition/consumer/SingleInputGateTest.java    | 46 ++++++++++++++++++++++
 .../partition/consumer/TestInputChannel.java       |  6 ++-
 .../partition/consumer/UnionInputGateTest.java     | 43 ++++++++++++++++++++
 .../runtime/io/AbstractStreamTaskNetworkInput.java |  4 +-
 .../streaming/runtime/io/DataInputStatus.java      |  3 ++
 .../runtime/io/MultipleInputSelectionHandler.java  |  7 +++-
 .../io/checkpointing/CheckpointedInputGate.java    |  5 +++
 .../flink/streaming/runtime/tasks/StreamTask.java  |  4 +-
 .../consumer/StreamTestSingleInputGate.java        |  2 +-
 .../streaming/runtime/io/MockIndexedInputGate.java |  5 +++
 .../flink/streaming/runtime/io/MockInputGate.java  |  5 +++
 .../AlignedCheckpointsMassiveRandomTest.java       |  5 +++
 ...tStreamTaskChainedSourcesCheckpointingTest.java | 17 ++++----
 .../runtime/tasks/MultipleInputStreamTaskTest.java |  9 +++--
 .../tasks/SourceOperatorStreamTaskTest.java        | 11 +++---
 .../runtime/tasks/SourceStreamTaskTest.java        |  7 ++--
 .../runtime/tasks/SourceTaskTerminationTest.java   |  2 +-
 .../tasks/StreamTaskFinalCheckpointsTest.java      | 16 ++++----
 .../runtime/tasks/StreamTaskTestHarness.java       |  2 +-
 .../runtime/tasks/TwoInputStreamTaskTest.java      |  3 +-
 39 files changed, 269 insertions(+), 67 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
index fa87f13..73869cb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/PullingAsyncDataInput.java
@@ -18,6 +18,7 @@
 package org.apache.flink.runtime.io;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.runtime.io.network.api.EndOfData;
 
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
@@ -65,4 +66,13 @@ public interface PullingAsyncDataInput<T> extends AvailabilityProvider {
      *     point
      */
     boolean hasReceivedEndOfData();
+
+    /**
+     * Tells if we should drain all results in case we received {@link EndOfData} on all channels.
+     * If any of the upstream subtasks finished because of the stop-with-savepoint --no-drain, we
+     * should not drain the current task. See also {@code StopMode}.
+     *
+     * <p>We should check the {@link #hasReceivedEndOfData()} first.
+     */
+    boolean shouldDrainOnEndOfData();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
index 7421f90..4e244f9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/EndOfData.java
@@ -23,6 +23,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.event.RuntimeEvent;
 
 import java.io.IOException;
+import java.util.Objects;
 
 /**
  * This event indicates there will be no more data records in a subpartition. There still might be
@@ -35,36 +36,57 @@ import java.io.IOException;
  */
 public class EndOfData extends RuntimeEvent {
 
-    /** The singleton instance of this event. */
-    public static final EndOfData INSTANCE = new EndOfData();
+    private final boolean shouldDrain;
 
     // ------------------------------------------------------------------------
 
-    // not instantiable
-    private EndOfData() {}
+    public EndOfData(boolean shouldDrain) {
+        this.shouldDrain = shouldDrain;
+    }
+
+    public boolean shouldDrain() {
+        return shouldDrain;
+    }
 
     // ------------------------------------------------------------------------
 
+    //
+    //  These methods are inherited from the generic serialization of AbstractEvent
+    //  but would require EndOfData to be mutable. Since all serialization
+    //  for events goes through the EventSerializer class, which has special serialization
+    //  for EndOfData (including the drain flag), we don't need these methods
+    //
     @Override
-    public void write(DataOutputView out) throws IOException {}
+    public void write(DataOutputView out) throws IOException {
+        throw new UnsupportedOperationException("This method should never be called");
+    }
 
     @Override
-    public void read(DataInputView in) throws IOException {}
+    public void read(DataInputView in) throws IOException {
+        throw new UnsupportedOperationException("This method should never be called");
+    }
 
     // ------------------------------------------------------------------------
 
     @Override
-    public int hashCode() {
-        return 1965146684;
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        EndOfData endOfData = (EndOfData) o;
+        return shouldDrain == endOfData.shouldDrain;
     }
 
     @Override
-    public boolean equals(Object obj) {
-        return obj != null && obj.getClass() == EndOfData.class;
+    public int hashCode() {
+        return Objects.hash(shouldDrain);
     }
 
     @Override
     public String toString() {
-        return getClass().getSimpleName();
+        return "EndOfData{shouldDrain=" + shouldDrain + '}';
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
index bca2feb..46ee08d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializer.java
@@ -94,7 +94,14 @@ public class EventSerializer {
         } else if (eventClass == EndOfChannelStateEvent.class) {
             return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_CHANNEL_STATE_EVENT});
         } else if (eventClass == EndOfData.class) {
-            return ByteBuffer.wrap(new byte[] {0, 0, 0, END_OF_USER_RECORDS_EVENT});
+            return ByteBuffer.wrap(
+                    new byte[] {
+                        0,
+                        0,
+                        0,
+                        END_OF_USER_RECORDS_EVENT,
+                        ((EndOfData) event).shouldDrain() ? (byte) 1 : (byte) 0
+                    });
         } else if (eventClass == CancelCheckpointMarker.class) {
             CancelCheckpointMarker marker = (CancelCheckpointMarker) event;
 
@@ -157,7 +164,7 @@ public class EventSerializer {
             } else if (type == END_OF_CHANNEL_STATE_EVENT) {
                 return EndOfChannelStateEvent.INSTANCE;
             } else if (type == END_OF_USER_RECORDS_EVENT) {
-                return EndOfData.INSTANCE;
+                return new EndOfData(buffer.get() == 1);
             } else if (type == CANCEL_CHECKPOINT_MARKER_EVENT) {
                 long id = buffer.getLong();
                 return new CancelCheckpointMarker(id);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index bd9a2ef..ff0160a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -69,8 +69,11 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
     /**
      * Notifies the downstream tasks that this {@code ResultPartitionWriter} have emitted all the
      * user records.
+     *
+     * @param shouldDrain tells if we should flush all records or not (it is false in case of
+     *     stop-with-savepoint (--no-drain))
      */
-    void notifyEndOfData() throws IOException;
+    void notifyEndOfData(boolean shouldDrain) throws IOException;
 
     /**
      * Gets the future indicating whether all the records has been processed by the downstream
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
index a3f1372..9b6a724 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingResultPartition.java
@@ -66,9 +66,9 @@ public class BoundedBlockingResultPartition extends BufferWritingResultPartition
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         if (!hasNotifiedEndOfUserRecords) {
-            broadcastEvent(EndOfData.INSTANCE, false);
+            broadcastEvent(new EndOfData(shouldDrain), false);
             hasNotifiedEndOfUserRecords = true;
         }
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
index f94d9ad..dd00bb1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedResultPartition.java
@@ -193,10 +193,10 @@ public class PipelinedResultPartition extends BufferWritingResultPartition
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         synchronized (lock) {
             if (!hasNotifiedEndOfUserRecords) {
-                broadcastEvent(EndOfData.INSTANCE, false);
+                broadcastEvent(new EndOfData(shouldDrain), false);
                 hasNotifiedEndOfUserRecords = true;
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index 1207fed..87561ec 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -200,7 +200,7 @@ public abstract class ResultPartition implements ResultPartitionWriter {
     // ------------------------------------------------------------------------
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         throw new UnsupportedOperationException();
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
index bc214ce..a8e8d31 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/SortMergeResultPartition.java
@@ -398,10 +398,10 @@ public class SortMergeResultPartition extends ResultPartition {
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         synchronized (lock) {
             if (!hasNotifiedEndOfUserRecords) {
-                broadcastEvent(EndOfData.INSTANCE, false);
+                broadcastEvent(new EndOfData(shouldDrain), false);
                 hasNotifiedEndOfUserRecords = true;
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
index 2253676..012aec6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGate.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.event.TaskEvent;
 import org.apache.flink.runtime.io.PullingAsyncDataInput;
+import org.apache.flink.runtime.io.network.api.EndOfData;
 import org.apache.flink.runtime.io.network.partition.ChannelStateHolder;
 
 import java.io.IOException;
@@ -101,6 +102,15 @@ public abstract class InputGate
     public abstract boolean hasReceivedEndOfData();
 
     /**
+     * Tells if we should drain all results in case we received {@link EndOfData} on all channels.
+     * If any of the upstream subtasks finished because of the stop-with-savepoint --no-drain, we
+     * should not drain the current task. See also {@code StopMode}.
+     *
+     * <p>We should check the {@link #hasReceivedEndOfData()} first.
+     */
+    public abstract boolean shouldDrainOnEndOfData();
+
+    /**
      * Blocking call waiting for next {@link BufferOrEvent}.
      *
      * <p>Note: It should be guaranteed that the previous returned buffer has been recycled before
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index f560ec4..ee78ff6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -208,6 +208,7 @@ public class SingleInputGate extends IndexedInputGate {
 
     private final ThroughputCalculator throughputCalculator;
     private final BufferDebloater bufferDebloater;
+    private boolean shouldDrainOnEndOfData = true;
 
     public SingleInputGate(
             String owningTaskName,
@@ -671,6 +672,11 @@ public class SingleInputGate extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return shouldDrainOnEndOfData;
+    }
+
+    @Override
     public String toString() {
         return "SingleInputGate{"
                 + "owningTaskName='"
@@ -849,6 +855,7 @@ public class SingleInputGate extends IndexedInputGate {
                 channelsWithEndOfUserRecords.set(currentChannel.getChannelIndex());
                 hasReceivedEndOfData =
                         channelsWithEndOfUserRecords.cardinality() == numberOfInputChannels;
+                shouldDrainOnEndOfData &= ((EndOfData) event).shouldDrain();
             }
         }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index ce8b5e8..4154615 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -76,6 +76,7 @@ public class UnionInputGate extends InputGate {
 
     private final Set<IndexedInputGate> inputGatesWithRemainingUserData;
 
+    private boolean shouldDrainOnEndOfData = true;
     /**
      * Gates, which notified this input gate about available data. We are using it as a FIFO queue
      * of {@link InputGate}s to avoid starvation and provide some basic fairness.
@@ -185,6 +186,11 @@ public class UnionInputGate extends InputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return shouldDrainOnEndOfData;
+    }
+
+    @Override
     public Optional<BufferOrEvent> getNext() throws IOException, InterruptedException {
         return getNextBufferOrEvent(true);
     }
@@ -289,6 +295,7 @@ public class UnionInputGate extends InputGate {
                 && bufferOrEvent.getEvent().getClass() == EndOfData.class
                 && inputGate.hasReceivedEndOfData()) {
 
+            shouldDrainOnEndOfData &= inputGate.shouldDrainOnEndOfData();
             if (!inputGatesWithRemainingUserData.remove(inputGate)) {
                 throw new IllegalStateException(
                         "Couldn't find input gate in set of remaining input gates.");
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
index f498ada..a7d425c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
@@ -149,8 +149,8 @@ public class ConsumableNotifyingResultPartitionWriterDecorator {
         }
 
         @Override
-        public void notifyEndOfData() throws IOException {
-            partitionWriter.notifyEndOfData();
+        public void notifyEndOfData(boolean shouldDrain) throws IOException {
+            partitionWriter.notifyEndOfData(shouldDrain);
         }
 
         @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
index 2a09d5f..dae257d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/InputGateWithMetrics.java
@@ -101,6 +101,11 @@ public class InputGateWithMetrics extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return inputGate.shouldDrainOnEndOfData();
+    }
+
+    @Override
     public void setup() throws IOException {
         inputGate.setup();
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
index 5445a02..6a2ef20 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/serialization/EventSerializerTest.java
@@ -49,7 +49,8 @@ public class EventSerializerTest {
     private final AbstractEvent[] events = {
         EndOfPartitionEvent.INSTANCE,
         EndOfSuperstepEvent.INSTANCE,
-        EndOfData.INSTANCE,
+        new EndOfData(true),
+        new EndOfData(false),
         new CheckpointBarrier(
                 1678L,
                 4623784L,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
index b95febb..bd2b667 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordOrEventCollectingResultPartitionWriter.java
@@ -89,9 +89,9 @@ public class RecordOrEventCollectingResultPartitionWriter<T>
     }
 
     @Override
-    public void notifyEndOfData() throws IOException {
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {
         if (collectNetworkEvents) {
-            broadcastEvent(EndOfData.INSTANCE, false);
+            broadcastEvent(new EndOfData(shouldDrain), false);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
index b0fcf42..9e1973f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/netty/PartitionRequestServerHandlerTest.java
@@ -128,7 +128,7 @@ public class PartitionRequestServerHandlerTest extends TestLogger {
         partitionRequestQueue.notifyReaderCreated(viewReader);
 
         // Write the message to acknowledge all records are processed to server
-        resultPartition.notifyEndOfData();
+        resultPartition.notifyEndOfData(true);
         CompletableFuture<Void> allRecordsProcessedFuture =
                 resultPartition.getAllDataProcessedFuture();
         assertFalse(allRecordsProcessedFuture.isDone());
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
index ed9b181..855ff82 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/BoundedBlockingSubpartitionWriteReadTest.java
@@ -256,7 +256,7 @@ public class BoundedBlockingSubpartitionWriteReadTest {
 
     private void writeEndOfData(BoundedBlockingSubpartition subpartition) throws IOException {
         try (BufferConsumer eventBufferConsumer =
-                EventSerializer.toBufferConsumer(EndOfData.INSTANCE, false)) {
+                EventSerializer.toBufferConsumer(new EndOfData(true), false)) {
             // Retain the buffer so that it can be recycled by each channel of targetPartition
             subpartition.add(eventBufferConsumer.copy(), 0);
         }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
index dc85d1c..6db08bd 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
@@ -61,7 +61,7 @@ public class MockResultPartitionWriter implements ResultPartitionWriter {
     public void broadcastEvent(AbstractEvent event, boolean isPriorityEvent) throws IOException {}
 
     @Override
-    public void notifyEndOfData() throws IOException {}
+    public void notifyEndOfData(boolean shouldDrain) throws IOException {}
 
     @Override
     public CompletableFuture<Void> getAllDataProcessedFuture() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
index e522bd4..06ac632 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
@@ -625,7 +625,7 @@ public class ResultPartitionTest {
         BufferWritingResultPartition bufferWritingResultPartition =
                 createResultPartition(ResultPartitionType.PIPELINED_BOUNDED);
 
-        bufferWritingResultPartition.notifyEndOfData();
+        bufferWritingResultPartition.notifyEndOfData(true);
         CompletableFuture<Void> allRecordsProcessedFuture =
                 bufferWritingResultPartition.getAllDataProcessedFuture();
         assertFalse(allRecordsProcessedFuture.isDone());
@@ -634,7 +634,7 @@ public class ResultPartitionTest {
             Buffer nextBuffer = ((PipelinedSubpartition) resultSubpartition).pollBuffer().buffer();
             assertFalse(nextBuffer.isBuffer());
             assertEquals(
-                    EndOfData.INSTANCE,
+                    new EndOfData(true),
                     EventSerializer.fromBuffer(nextBuffer, getClass().getClassLoader()));
         }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index c543014..1ea926ac 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -254,6 +254,52 @@ public class SingleInputGateTest extends InputGateTestBase {
         }
     }
 
+    @Test
+    public void testDrainFlagComputation() throws Exception {
+        // Setup
+        final SingleInputGate inputGate1 = createInputGate();
+        final SingleInputGate inputGate2 = createInputGate();
+
+        final TestInputChannel[] inputChannels1 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate1, 0), new TestInputChannel(inputGate1, 1)
+                };
+        inputGate1.setInputChannels(inputChannels1);
+        final TestInputChannel[] inputChannels2 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate2, 0), new TestInputChannel(inputGate2, 1)
+                };
+        inputGate2.setInputChannels(inputChannels2);
+
+        // Test
+        inputChannels1[1].readEndOfData(true);
+        inputChannels1[0].readEndOfData(false);
+
+        inputChannels2[1].readEndOfData(true);
+        inputChannels2[0].readEndOfData(true);
+
+        inputGate1.notifyChannelNonEmpty(inputChannels1[0]);
+        inputGate1.notifyChannelNonEmpty(inputChannels1[1]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[0]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[1]);
+
+        verifyBufferOrEvent(inputGate1, false, 0, true);
+        // we have received EndOfData on a single channel only
+        assertFalse(inputGate1.hasReceivedEndOfData());
+        verifyBufferOrEvent(inputGate1, false, 1, true);
+        assertTrue(inputGate1.hasReceivedEndOfData());
+        // one of the channels said we should not drain
+        assertFalse(inputGate1.shouldDrainOnEndOfData());
+
+        verifyBufferOrEvent(inputGate2, false, 0, true);
+        // we have received EndOfData on a single channel only
+        assertFalse(inputGate2.hasReceivedEndOfData());
+        verifyBufferOrEvent(inputGate2, false, 1, true);
+        assertTrue(inputGate2.hasReceivedEndOfData());
+        // both channels said we should drain
+        assertTrue(inputGate2.shouldDrainOnEndOfData());
+    }
+
     /**
      * Tests that the compressed buffer will be decompressed after calling {@link
      * SingleInputGate#getNext()}.
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
index c89b470..3011d3c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/TestInputChannel.java
@@ -107,9 +107,13 @@ public class TestInputChannel extends InputChannel {
     }
 
     TestInputChannel readEndOfData() throws IOException {
+        return readEndOfData(true);
+    }
+
+    TestInputChannel readEndOfData(boolean shouldDrain) throws IOException {
         addBufferAndAvailability(
                 new BufferAndAvailability(
-                        EventSerializer.toBuffer(EndOfData.INSTANCE, false),
+                        EventSerializer.toBuffer(new EndOfData(shouldDrain), false),
                         Buffer.DataType.EVENT_BUFFER,
                         0,
                         sequenceNumber++));
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
index b5bd5e6..c9ca40e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGateTest.java
@@ -133,6 +133,49 @@ public class UnionInputGateTest extends InputGateTestBase {
     }
 
     @Test
+    public void testDrainFlagComputation() throws Exception {
+        // Setup
+        final SingleInputGate inputGate1 = createInputGate();
+        final SingleInputGate inputGate2 = createInputGate();
+
+        final TestInputChannel[] inputChannels1 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate1, 0), new TestInputChannel(inputGate1, 1)
+                };
+        inputGate1.setInputChannels(inputChannels1);
+        final TestInputChannel[] inputChannels2 =
+                new TestInputChannel[] {
+                    new TestInputChannel(inputGate2, 0), new TestInputChannel(inputGate2, 1)
+                };
+        inputGate2.setInputChannels(inputChannels2);
+
+        // Test
+        inputChannels1[1].readEndOfData(true);
+        inputChannels1[0].readEndOfData(false);
+
+        inputChannels2[1].readEndOfData(true);
+        inputChannels2[0].readEndOfData(true);
+
+        final UnionInputGate unionInputGate = new UnionInputGate(inputGate1, inputGate2);
+
+        inputGate1.notifyChannelNonEmpty(inputChannels1[0]);
+        inputGate1.notifyChannelNonEmpty(inputChannels1[1]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[0]);
+        inputGate2.notifyChannelNonEmpty(inputChannels2[1]);
+
+        verifyBufferOrEvent(unionInputGate, false, 0, true);
+        verifyBufferOrEvent(unionInputGate, false, 2, true);
+        // each gate has received EndOfData on only one of its two channels so far
+        assertFalse(unionInputGate.hasReceivedEndOfData());
+
+        verifyBufferOrEvent(unionInputGate, false, 1, true);
+        verifyBufferOrEvent(unionInputGate, false, 3, true);
+        // all channels received EndOfData, but one of them said we should not drain
+        assertTrue(unionInputGate.hasReceivedEndOfData());
+        assertFalse(unionInputGate.shouldDrainOnEndOfData());
+    }
+
+    @Test
     public void testIsAvailable() throws Exception {
         final SingleInputGate inputGate1 = createInputGate(1);
         TestInputChannel inputChannel1 = new TestInputChannel(inputGate1, 0);
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
index 8e84865..d60c824 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/AbstractStreamTaskNetworkInput.java
@@ -152,7 +152,9 @@ public abstract class AbstractStreamTaskNetworkInput<
         final AbstractEvent event = bufferOrEvent.getEvent();
         if (event.getClass() == EndOfData.class) {
             if (checkpointedInputGate.hasReceivedEndOfData()) {
-                return DataInputStatus.END_OF_DATA;
+                return checkpointedInputGate.shouldDrainOnEndOfData()
+                        ? DataInputStatus.END_OF_DATA
+                        : DataInputStatus.STOPPED; // stop-with-savepoint without drain
             }
         } else if (event.getClass() == EndOfPartitionEvent.class) {
             // release the record deserializer immediately,
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java
index 8fe97ff..d1e45fa 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/DataInputStatus.java
@@ -44,6 +44,9 @@ public enum DataInputStatus {
     /** Indicator that all persisted data of the data exchange has been successfully restored. */
     END_OF_RECOVERY,
 
+    /** Indicator that the input was stopped because of stop-with-savepoint without drain. */
+    STOPPED, // returned instead of END_OF_DATA when the inputs should not be drained
+
     /** Indicator that the input has reached the end of data. */
     END_OF_DATA,
 
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java
index 31be2b7..3405d7a 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/MultipleInputSelectionHandler.java
@@ -49,6 +49,8 @@ public class MultipleInputSelectionHandler {
 
     private long dataFinishedButNotPartition;
 
+    private boolean drainOnEndOfData = true; // cleared once any input reports STOPPED
+
     private enum OperatingMode {
         NO_INPUT_SELECTABLE,
         INPUT_SELECTABLE_PRESENT_NO_DATA_INPUTS_FINISHED,
@@ -91,6 +93,9 @@ public class MultipleInputSelectionHandler {
             case NOTHING_AVAILABLE:
                 availableInputsMask = unsetBitMask(availableInputsMask, inputIndex);
                 break;
+            case STOPPED:
+                this.drainOnEndOfData = false;
+                // fall through
             case END_OF_DATA:
                 dataFinishedButNotPartition = setBitMask(dataFinishedButNotPartition, inputIndex);
                 updateModeOnEndOfData();
@@ -126,7 +131,9 @@ public class MultipleInputSelectionHandler {
 
-        if (updatedStatus == DataInputStatus.END_OF_DATA
+        // The last input may finish with STOPPED instead of END_OF_DATA; handle both.
+        if ((updatedStatus == DataInputStatus.END_OF_DATA
+                        || updatedStatus == DataInputStatus.STOPPED)
                 && this.operatingMode == OperatingMode.ALL_DATA_INPUTS_FINISHED) {
-            return DataInputStatus.END_OF_DATA;
+            return drainOnEndOfData ? DataInputStatus.END_OF_DATA : DataInputStatus.STOPPED;
         }
 
         if (isAnyInputAvailable()) {
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
index 048d446..b497e4d 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/io/checkpointing/CheckpointedInputGate.java
@@ -227,6 +227,11 @@ public class CheckpointedInputGate implements PullingAsyncDataInput<BufferOrEven
         return inputGate.hasReceivedEndOfData();
     }
 
+    @Override
+    public boolean shouldDrainOnEndOfData() {
+        return inputGate.shouldDrainOnEndOfData(); // delegates to the wrapped input gate
+    }
+
     /**
      * Cleans up all internally held resources.
      *
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index 2003c7b..06c6874 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -526,6 +526,8 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
                 break;
             case END_OF_RECOVERY:
                 throw new IllegalStateException("We should not receive this event here.");
+            case STOPPED: // stop-with-savepoint without drain; not handled here yet
+                throw new UnsupportedOperationException("Not supported yet");
             case END_OF_DATA:
                 endData();
                 return;
@@ -575,7 +577,7 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
         this.finishedOperators = true;
 
         for (ResultPartitionWriter partitionWriter : getEnvironment().getAllWriters()) {
-            partitionWriter.notifyEndOfData();
+            partitionWriter.notifyEndOfData(true);
         }
 
         this.endOfDataReceived = true;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
index cc29517..32b6e96 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/StreamTestSingleInputGate.java
@@ -124,7 +124,7 @@ public class StreamTestSingleInputGate<T> {
                         } else if (input != null && input.isDataEnd()) {
                             return Optional.of(
                                     new BufferAndAvailability(
-                                            EventSerializer.toBuffer(EndOfData.INSTANCE, false),
+                                            EventSerializer.toBuffer(new EndOfData(true), false),
                                             nextType,
                                             0,
                                             0));
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
index d7e7ae1..8bbe8fa 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockIndexedInputGate.java
@@ -101,6 +101,11 @@ public class MockIndexedInputGate extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() {
+        return false; // this mock never asks for draining
+    }
+
+    @Override
     public Optional<BufferOrEvent> getNext() {
         throw new UnsupportedOperationException();
     }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
index 43c4b97..993361e 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/MockInputGate.java
@@ -119,6 +119,11 @@ public class MockInputGate extends IndexedInputGate {
     }
 
     @Override
+    public boolean shouldDrainOnEndOfData() { // not supported by this mock yet
+        throw new UnsupportedOperationException("Not implemented yet");
+    }
+
+    @Override
     public Optional<BufferOrEvent> getNext() {
         BufferOrEvent next = bufferOrEvents.poll();
         if (!finishAfterLastBuffer && bufferOrEvents.isEmpty()) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
index f35bd70..903ec71 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/checkpointing/AlignedCheckpointsMassiveRandomTest.java
@@ -173,6 +173,11 @@ public class AlignedCheckpointsMassiveRandomTest {
         }
 
         @Override
+        public boolean shouldDrainOnEndOfData() {
+            return false; // never requests draining
+        }
+
+        @Override
         public InputChannel getChannel(int channelIndex) {
             throw new UnsupportedOperationException();
         }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
index 9ceeae7..d433b69 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskChainedSourcesCheckpointingTest.java
@@ -281,10 +281,10 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
             CheckpointBarrier barrier = createStopWithSavepointDrainBarrier();
 
             testHarness.processElement(new StreamRecord<>("44", TimestampAssigner.NO_TIMESTAMP), 0);
-            testHarness.processEvent(EndOfData.INSTANCE, 0);
+            testHarness.processEvent(new EndOfData(true), 0); // drain, matching the barrier
             testHarness.processEvent(barrier, 0);
             testHarness.processElement(new StreamRecord<>(47d, TimestampAssigner.NO_TIMESTAMP), 1);
-            testHarness.processEvent(EndOfData.INSTANCE, 1);
+            testHarness.processEvent(new EndOfData(true), 1);
             testHarness.processEvent(barrier, 1);
 
             addSourceRecords(testHarness, 1, Boundedness.CONTINUOUS_UNBOUNDED, 1, 2);
@@ -311,7 +311,7 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                     containsInAnyOrder(expectedOutput.toArray()));
             assertThat(
                     actualOutput.subList(actualOutput.size() - 3, actualOutput.size()),
-                    contains(new StreamRecord<>("FINISH"), EndOfData.INSTANCE, barrier));
+                    contains(new StreamRecord<>("FINISH"), new EndOfData(true), barrier));
         }
     }
 
@@ -435,8 +435,8 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                 testHarness.processAll();
 
                 // The checkpoint 2 would be aligned after received all the EndOfPartitionEvent.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
-                testHarness.processEvent(EndOfData.INSTANCE, 1, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
+                testHarness.processEvent(new EndOfData(true), 1, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
                 testHarness.getTaskStateManager().getWaitForReportLatch().await();
@@ -493,8 +493,9 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData() throws IOException {
-                                        broadcastEvent(EndOfData.INSTANCE, false);
+                                    public void notifyEndOfData(boolean shouldDrain)
+                                            throws IOException {
+                                        broadcastEvent(new EndOfData(shouldDrain), false);
                                     }
                                 })
                         .addSourceInput(
@@ -522,7 +523,7 @@ public class MultipleInputStreamTaskChainedSourcesCheckpointingTest {
             testHarness.processElement(Watermark.MAX_WATERMARK);
             assertThat(output, is(empty()));
             testHarness.waitForTaskCompletion();
-            assertThat(output, contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
 
             for (StreamOperatorWrapper<?, ?> wrapper :
                     testHarness.getStreamTask().operatorChain.getAllOperators()) {
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
index f8eb733..c31dd3d 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/MultipleInputStreamTaskTest.java
@@ -989,15 +989,15 @@ public class MultipleInputStreamTaskTest {
                 assertEquals(2, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0); // true: drain
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4, checkpointOptions);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
                 assertEquals(4, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after all the inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 1, 0);
-                testHarness.processEvent(EndOfData.INSTANCE, 2, 0);
+                testHarness.processEvent(new EndOfData(true), 1, 0);
+                testHarness.processEvent(new EndOfData(true), 2, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 1, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 2, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 6, checkpointOptions);
@@ -1058,7 +1058,8 @@ public class MultipleInputStreamTaskTest {
             testHarness.processElement(Watermark.MAX_WATERMARK, 2);
             testHarness.waitForTaskCompletion();
             assertThat(
-                    testHarness.getOutput(), contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+                    testHarness.getOutput(),
+                    contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
         }
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
index ffe4c15..d4fa9be 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceOperatorStreamTaskTest.java
@@ -129,7 +129,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             Queue<Object> expectedOutput = new LinkedList<>();
             expectedOutput.add(Watermark.MAX_WATERMARK);
-            expectedOutput.add(EndOfData.INSTANCE);
+            expectedOutput.add(new EndOfData(true)); // with the drain flag set
             expectedOutput.add(
                     new CheckpointBarrier(checkpointId, checkpointId, checkpointOptions));
 
@@ -145,7 +145,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             Queue<Object> expectedOutput = new LinkedList<>();
             expectedOutput.add(Watermark.MAX_WATERMARK);
-            expectedOutput.add(EndOfData.INSTANCE);
+            expectedOutput.add(new EndOfData(true));
             assertThat(testHarness.getOutput().toArray(), equalTo(expectedOutput.toArray()));
         }
     }
@@ -203,8 +203,9 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData() throws IOException {
-                                        broadcastEvent(EndOfData.INSTANCE, false);
+                                    public void notifyEndOfData(boolean shouldDrain)
+                                            throws IOException {
+                                        broadcastEvent(new EndOfData(shouldDrain), false);
                                     }
                                 })
                         .setupOperatorChain(sourceOperatorFactory)
@@ -214,7 +215,7 @@ public class SourceOperatorStreamTaskTest extends SourceStreamTaskTestBase {
 
             testHarness.getStreamTask().invoke();
             testHarness.processAll();
-            assertThat(output, contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
 
             LifeCycleMonitorSourceReader sourceReader =
                     (LifeCycleMonitorSourceReader)
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
index 5fbc65c..6553279 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceStreamTaskTest.java
@@ -720,8 +720,9 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
                                         output,
                                         new StreamElementSerializer<>(IntSerializer.INSTANCE)) {
                                     @Override
-                                    public void notifyEndOfData() throws IOException {
-                                        broadcastEvent(EndOfData.INSTANCE, false);
+                                    public void notifyEndOfData(boolean shouldDrain)
+                                            throws IOException { // forwards the drain flag
+                                        broadcastEvent(new EndOfData(shouldDrain), false);
                                     }
                                 })
                         .setupOperatorChain(new StreamSource<>(testSource))
@@ -732,7 +733,7 @@ public class SourceStreamTaskTest extends SourceStreamTaskTestBase {
             harness.processAll();
             harness.streamTask.getCompletionFuture().get();
 
-            assertThat(output, contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+            assertThat(output, contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
 
             LifeCycleMonitorSource source =
                     (LifeCycleMonitorSource)
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
index 8cf461c..5a67e86 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/SourceTaskTerminationTest.java
@@ -107,7 +107,7 @@ public class SourceTaskTerminationTest extends TestLogger {
                 // if we are in TERMINATE mode, we expect the source task
                 // to emit MAX_WM before the SYNC_SAVEPOINT barrier.
                 verifyWatermark(srcTaskTestHarness.getOutput(), Watermark.MAX_WATERMARK);
-                verifyEvent(srcTaskTestHarness.getOutput(), EndOfData.INSTANCE);
+                verifyEvent(srcTaskTestHarness.getOutput(), new EndOfData(true)); // drain
             }
 
             verifyCheckpointBarrier(srcTaskTestHarness.getOutput(), syncSavepointId);
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
index efd0ea1..88087d5 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskFinalCheckpointsTest.java
@@ -155,7 +155,7 @@ public class StreamTaskFinalCheckpointsTest {
                 assertEquals(2, testHarness.getTaskStateManager().getReportedCheckpointId());
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0); // true: drain
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
@@ -163,8 +163,8 @@ public class StreamTaskFinalCheckpointsTest {
 
                 // Tests triggering checkpoint after received all the inputs have received
                 // EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 1);
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 2);
+                testHarness.processEvent(new EndOfData(true), 0, 1);
+                testHarness.processEvent(new EndOfData(true), 0, 2);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
                 checkpointFuture = triggerCheckpoint(testHarness, lastCheckpointId);
@@ -664,7 +664,7 @@ public class StreamTaskFinalCheckpointsTest {
                 assertArrayEquals(new int[] {0, 0, 0}, resumedCount);
 
                 // Tests triggering checkpoint after some inputs have received EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4, checkpointOptions);
                 processMailTillCheckpointSucceeds(testHarness, checkpointFuture);
@@ -673,8 +673,8 @@ public class StreamTaskFinalCheckpointsTest {
 
                 // Tests triggering checkpoint after received all the inputs have received
                 // EndOfPartition.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 1);
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 2);
+                testHarness.processEvent(new EndOfData(true), 0, 1);
+                testHarness.processEvent(new EndOfData(true), 0, 2);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 1);
                 testHarness.processEvent(EndOfPartitionEvent.INSTANCE, 0, 2);
                 checkpointFuture = triggerCheckpoint(testHarness, 6, checkpointOptions);
@@ -759,7 +759,7 @@ public class StreamTaskFinalCheckpointsTest {
                 // The checkpoint is added to the mailbox and will be processed in the
                 // mailbox loop after call operators' finish method in the afterInvoke()
                 // method.
-                testHarness.processEvent(EndOfData.INSTANCE, 0, 0);
+                testHarness.processEvent(new EndOfData(true), 0, 0);
                 checkpointFuture = triggerCheckpoint(testHarness, 4);
                 checkpointFuture.thenAccept(
                         (ignored) -> {
@@ -947,7 +947,7 @@ public class StreamTaskFinalCheckpointsTest {
                                     checkpointMetaData.getTimestamp(),
                                     checkpointOptions),
                             Watermark.MAX_WATERMARK,
-                            EndOfData.INSTANCE));
+                            new EndOfData(true)));
         }
     }
 
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index 2bae6ed..a95b17c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -492,7 +492,7 @@ public class StreamTaskTestHarness<OUT> {
 
     public void endInput(int gateIndex, int channelIndex, boolean emitEndOfData) {
         if (emitEndOfData) {
-            inputGates[gateIndex].sendEvent(EndOfData.INSTANCE, channelIndex);
+            inputGates[gateIndex].sendEvent(new EndOfData(true), channelIndex); // drain
         }
         inputGates[gateIndex].sendEvent(EndOfPartitionEvent.INSTANCE, channelIndex);
     }
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
index 2686ad1..ab8304f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/TwoInputStreamTaskTest.java
@@ -542,7 +542,8 @@ public class TwoInputStreamTaskTest {
             testHarness.processElement(Watermark.MAX_WATERMARK, 1);
             testHarness.waitForTaskCompletion();
             assertThat(
-                    testHarness.getOutput(), contains(Watermark.MAX_WATERMARK, EndOfData.INSTANCE));
+                    testHarness.getOutput(),
+                    contains(Watermark.MAX_WATERMARK, new EndOfData(true)));
         }
     }