Posted to commits@flink.apache.org by yi...@apache.org on 2022/08/08 14:58:42 UTC

[flink] 02/02: [FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side

This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 72405361610f1576e0f57ea4e957fafbbbaf0910
Author: kevin.cyj <ke...@alibaba-inc.com>
AuthorDate: Mon Jul 25 16:03:28 2022 +0800

    [FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side
    
    This closes #20350.
---
 .../optimizer/plantranslate/JobGraphGenerator.java |   5 +-
 .../TaskDeploymentDescriptorFactory.java           |   9 +-
 .../executiongraph/DefaultExecutionGraph.java      |  15 +-
 .../flink/runtime/executiongraph/EdgeManager.java  |  24 +++-
 .../executiongraph/EdgeManagerBuildUtil.java       |  21 +--
 .../flink/runtime/executiongraph/Execution.java    |  61 +++++----
 .../runtime/executiongraph/IntermediateResult.java |  70 ++++++++--
 .../IntermediateResultPartition.java               |  58 +++++---
 .../RestartPipelinedRegionFailoverStrategy.java    |  12 +-
 .../SchedulingPipelinedRegionComputeUtil.java      |  32 ++---
 .../io/network/NettyShuffleEnvironment.java        |  36 +++--
 .../runtime/jobgraph/IntermediateDataSet.java      |  40 ++++--
 .../org/apache/flink/runtime/jobgraph/JobEdge.java |  48 ++-----
 .../apache/flink/runtime/jobgraph/JobGraph.java    |   5 +-
 .../apache/flink/runtime/jobgraph/JobVertex.java   |  41 +++---
 .../SsgNetworkMemoryCalculationUtils.java          |  30 ++--
 .../adapter/DefaultExecutionTopology.java          |  10 +-
 .../scheduler/adapter/DefaultResultPartition.java  |  11 +-
 .../scheduler/strategy/ConsumedPartitionGroup.java |  19 ++-
 .../strategy/SchedulingResultPartition.java        |   7 +-
 .../strategy/VertexwiseSchedulingStrategy.java     |  10 +-
 .../BlockingResultPartitionReleaseTest.java        | 151 +++++++++++++++++++++
 .../DefaultExecutionGraphConstructionTest.java     |  41 ++++++
 .../executiongraph/EdgeManagerBuildUtilTest.java   |   2 +-
 .../runtime/executiongraph/EdgeManagerTest.java    |  88 +++++++++---
 .../executiongraph/ExecutionGraphTestUtils.java    |  19 +++
 .../executiongraph/ExecutionJobVertexTest.java     |   2 +-
 .../IntermediateResultPartitionTest.java           |  62 ++++++++-
 .../RemoveCachedShuffleDescriptorTest.java         | 133 ++++--------------
 .../partition/NoOpJobMasterPartitionTracker.java   |   5 +-
 .../flink/runtime/jobgraph/JobTaskVertexTest.java  |  43 +++++-
 .../jobmaster/JobIntermediateDatasetReuseTest.java |   3 +-
 .../runtime/scheduler/SchedulerTestingUtils.java   | 137 +++++++++++++++++++
 .../adapter/DefaultExecutionTopologyTest.java      |  19 ++-
 .../adapter/DefaultExecutionVertexTest.java        |   1 +
 .../adapter/DefaultResultPartitionTest.java        |  20 ++-
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  |   2 +-
 .../forwardgroup/ForwardGroupComputeUtilTest.java  |  12 +-
 .../strategy/TestingSchedulingExecutionVertex.java |   7 +-
 .../strategy/TestingSchedulingResultPartition.java |   9 +-
 .../strategy/TestingSchedulingTopology.java        |   1 +
 .../api/graph/StreamingJobGraphGenerator.java      |   7 +-
 .../validation/TestJobDataFlowValidator.java       |  34 ++---
 43 files changed, 958 insertions(+), 404 deletions(-)

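The essence of the change: a producer JobVertex can now expose one IntermediateDataSet to several consumer job vertices. As a rough illustration only (not part of the commit; the vertex names and setup are assumed, the classes come from org.apache.flink.runtime.jobgraph, and the method signatures follow the JobVertex and JobEdge hunks below), two consumers can be wired against the same dataset id:

    IntermediateDataSetID sharedDataSetId = new IntermediateDataSetID();

    JobVertex producer = new JobVertex("producer");
    JobVertex consumerA = new JobVertex("consumer-a");
    JobVertex consumerB = new JobVertex("consumer-b");

    // Both consumers connect against the same dataset id; getOrCreateResultDataSet
    // lets the second call reuse the dataset created by the first one instead of
    // registering a second result on the producer.
    consumerA.connectNewDataSetAsInput(
            producer, DistributionPattern.ALL_TO_ALL,
            ResultPartitionType.BLOCKING, sharedDataSetId, false);
    consumerB.connectNewDataSetAsInput(
            producer, DistributionPattern.ALL_TO_ALL,
            ResultPartitionType.BLOCKING, sharedDataSetId, false);

    // producer.getProducedDataSets() now contains a single IntermediateDataSet
    // whose getConsumers() lists both JobEdges.
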
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
index 0da624f1e7f..887da47ed0c 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
@@ -1246,7 +1246,7 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
                 predecessorVertex != null,
                 "Bug: Chained task has not been assigned its containing vertex when connecting.");
 
-        predecessorVertex.createAndAddResultDataSet(
+        predecessorVertex.getOrCreateResultDataSet(
                 // use specified intermediateDataSetID
                 new IntermediateDataSetID(
                         ((BlockingShuffleOutputFormat) userCodeObject).getIntermediateDataSetId()),
@@ -1326,7 +1326,7 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
 
         JobEdge edge =
                 targetVertex.connectNewDataSetAsInput(
-                        sourceVertex, distributionPattern, resultType);
+                        sourceVertex, distributionPattern, resultType, isBroadcast);
 
         // -------------- configure the source task's ship strategy strategies in task config
         // --------------
@@ -1403,7 +1403,6 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
                 channel.getTempMode() == TempMode.NONE ? null : channel.getTempMode().toString();
 
         edge.setShipStrategyName(shipStrategy);
-        edge.setBroadcast(isBroadcast);
         edge.setForward(channel.getShipStrategy() == ShipStrategyType.FORWARD);
         edge.setPreProcessingOperationName(localStrategy);
         edge.setOperatorLevelCachingDescription(caching);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index a2c45ed0c47..ba10dd59de8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -133,7 +133,9 @@ public class TaskDeploymentDescriptorFactory {
             IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
             SubpartitionIndexRange consumedSubpartitionRange =
                     computeConsumedSubpartitionRange(
-                            resultPartition, executionId.getSubtaskIndex());
+                            consumedPartitionGroup.getNumConsumers(),
+                            resultPartition,
+                            executionId.getSubtaskIndex());
 
             IntermediateDataSetID resultId = consumedIntermediateResult.getId();
             ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
@@ -164,8 +166,9 @@ public class TaskDeploymentDescriptorFactory {
     }
 
     public static SubpartitionIndexRange computeConsumedSubpartitionRange(
-            IntermediateResultPartition resultPartition, int consumerSubtaskIndex) {
-        int numConsumers = resultPartition.getConsumerVertexGroup().size();
+            int numConsumers,
+            IntermediateResultPartition resultPartition,
+            int consumerSubtaskIndex) {
         int consumerIndex = consumerSubtaskIndex % numConsumers;
         IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
         int numSubpartitions = resultPartition.getNumberOfSubpartitions();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 2f160f14605..ce1f18c5504 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -1399,6 +1399,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
             final List<ConsumedPartitionGroup> releasablePartitionGroups) {
 
         if (releasablePartitionGroups.size() > 0) {
+            final List<ResultPartitionID> releasablePartitionIds = new ArrayList<>();
 
             // Remove the cache of ShuffleDescriptors when ConsumedPartitionGroups are released
             for (ConsumedPartitionGroup releasablePartitionGroup : releasablePartitionGroups) {
@@ -1406,15 +1407,17 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
                         checkNotNull(
                                 intermediateResults.get(
                                         releasablePartitionGroup.getIntermediateDataSetID()));
+                for (IntermediateResultPartitionID partitionId : releasablePartitionGroup) {
+                    IntermediateResultPartition partition =
+                            totalResult.getPartitionById(partitionId);
+                    partition.markPartitionGroupReleasable(releasablePartitionGroup);
+                    if (partition.canBeReleased()) {
+                        releasablePartitionIds.add(createResultPartitionId(partitionId));
+                    }
+                }
                 totalResult.clearCachedInformationForPartitionGroup(releasablePartitionGroup);
             }
 
-            final List<ResultPartitionID> releasablePartitionIds =
-                    releasablePartitionGroups.stream()
-                            .flatMap(IterableUtils::toStream)
-                            .map(this::createResultPartitionId)
-                            .collect(Collectors.toList());
-
             partitionTracker.stopTrackingAndReleasePartitions(releasablePartitionIds);
         }
     }
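
With several consumer groups per dataset, a partition now has to wait until every ConsumedPartitionGroup referencing it is releasable before it can be released. A minimal standalone sketch of that counting rule (the class below is illustrative, not a Flink API):

    // Illustrative stand-in, not a Flink class: a partition may be released only
    // once every ConsumedPartitionGroup referencing it has been marked releasable.
    final class PartitionReleaseCondition {
        // from EdgeManager#getNumberOfConsumedPartitionGroupsById in this commit
        private final int numReferencingGroups;
        private final java.util.Set<Object> releasableGroups = new java.util.HashSet<>();

        PartitionReleaseCondition(int numReferencingGroups) {
            this.numReferencingGroups = numReferencingGroups;
        }

        void markPartitionGroupReleasable(Object group) {
            releasableGroups.add(group);
        }

        boolean canBeReleased() {
            return releasableGroups.size() == numReferencingGroups;
        }
    }
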
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
index 8efc25a913b..1a437e16030 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
@@ -30,12 +30,11 @@ import java.util.List;
 import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
-import static org.apache.flink.util.Preconditions.checkState;
 
 /** Class that manages all the connections between tasks. */
 public class EdgeManager {
 
-    private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> partitionConsumers =
+    private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>> partitionConsumers =
             new HashMap<>();
 
     private final Map<ExecutionVertexID, List<ConsumedPartitionGroup>> vertexConsumedPartitions =
@@ -50,9 +49,9 @@ public class EdgeManager {
 
         checkNotNull(consumerVertexGroup);
 
-        checkState(
-                partitionConsumers.putIfAbsent(resultPartitionId, consumerVertexGroup) == null,
-                "Currently one partition can have at most one consumer group");
+        List<ConsumerVertexGroup> groups =
+                getConsumerVertexGroupsForPartitionInternal(resultPartitionId);
+        groups.add(consumerVertexGroup);
     }
 
     public void connectVertexWithConsumedPartitionGroup(
@@ -66,14 +65,20 @@ public class EdgeManager {
         consumedPartitions.add(consumedPartitionGroup);
     }
 
+    private List<ConsumerVertexGroup> getConsumerVertexGroupsForPartitionInternal(
+            IntermediateResultPartitionID resultPartitionId) {
+        return partitionConsumers.computeIfAbsent(resultPartitionId, id -> new ArrayList<>());
+    }
+
     private List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertexInternal(
             ExecutionVertexID executionVertexId) {
         return vertexConsumedPartitions.computeIfAbsent(executionVertexId, id -> new ArrayList<>());
     }
 
-    public ConsumerVertexGroup getConsumerVertexGroupForPartition(
+    public List<ConsumerVertexGroup> getConsumerVertexGroupsForPartition(
             IntermediateResultPartitionID resultPartitionId) {
-        return partitionConsumers.get(resultPartitionId);
+        return Collections.unmodifiableList(
+                getConsumerVertexGroupsForPartitionInternal(resultPartitionId));
     }
 
     public List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertex(
@@ -100,4 +105,9 @@ public class EdgeManager {
         return Collections.unmodifiableList(
                 getConsumedPartitionGroupsByIdInternal(resultPartitionId));
     }
+
+    public int getNumberOfConsumedPartitionGroupsById(
+            IntermediateResultPartitionID resultPartitionId) {
+        return getConsumedPartitionGroupsByIdInternal(resultPartitionId).size();
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
index 3aa6e59de0d..8ac55b7b7a2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
@@ -89,7 +89,7 @@ public class EdgeManagerBuildUtil {
                         .collect(Collectors.toList());
         ConsumedPartitionGroup consumedPartitionGroup =
                 createAndRegisterConsumedPartitionGroupToEdgeManager(
-                        consumedPartitions, intermediateResult);
+                        taskVertices.length, consumedPartitions, intermediateResult);
         for (ExecutionVertex ev : taskVertices) {
             ev.addConsumedPartitionGroup(consumedPartitionGroup);
         }
@@ -122,7 +122,9 @@ public class EdgeManagerBuildUtil {
 
                 ConsumedPartitionGroup consumedPartitionGroup =
                         createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                partition.getPartitionId(), intermediateResult);
+                                consumerVertexGroup.size(),
+                                partition.getPartitionId(),
+                                intermediateResult);
                 executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
             }
         } else if (sourceCount > targetCount) {
@@ -147,20 +149,19 @@ public class EdgeManagerBuildUtil {
 
                 ConsumedPartitionGroup consumedPartitionGroup =
                         createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                consumedPartitions, intermediateResult);
+                                consumerVertexGroup.size(), consumedPartitions, intermediateResult);
                 executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
             }
         } else {
             for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
+                int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
+                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
 
                 IntermediateResultPartition partition =
                         intermediateResult.getPartitions()[partitionNum];
                 ConsumedPartitionGroup consumedPartitionGroup =
                         createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                partition.getPartitionId(), intermediateResult);
-
-                int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
-                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
+                                end - start, partition.getPartitionId(), intermediateResult);
 
                 List<ExecutionVertexID> consumers = new ArrayList<>(end - start);
 
@@ -179,21 +180,23 @@ public class EdgeManagerBuildUtil {
     }
 
     private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
+            int numConsumers,
             IntermediateResultPartitionID consumedPartitionId,
             IntermediateResult intermediateResult) {
         ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromSinglePartition(
-                        consumedPartitionId, intermediateResult.getResultType());
+                        numConsumers, consumedPartitionId, intermediateResult.getResultType());
         registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
         return consumedPartitionGroup;
     }
 
     private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
+            int numConsumers,
             List<IntermediateResultPartitionID> consumedPartitions,
             IntermediateResult intermediateResult) {
         ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromMultiplePartitions(
-                        consumedPartitions, intermediateResult.getResultType());
+                        numConsumers, consumedPartitions, intermediateResult.getResultType());
         registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
         return consumedPartitionGroup;
     }
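
For the pointwise branch above (fewer producer partitions than consumer subtasks), the start/end formulas decide which consumer subtasks each partition feeds, and end - start is the consumer count now handed to the ConsumedPartitionGroup. A small worked example with illustrative values, not taken from the commit:

    // 2 producer partitions feeding 5 consumer subtasks.
    int sourceCount = 2;
    int targetCount = 5;
    for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
        int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
        int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
        // partition 0 -> consumers [0, 3), partition 1 -> consumers [3, 5)
        System.out.printf("partition %d -> consumers [%d, %d), numConsumers=%d%n",
                partitionNum, start, end, end - start);
    }
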
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 657effb66f6..f6de9059869 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -65,6 +65,7 @@ import javax.annotation.Nullable;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -497,10 +498,7 @@ public class Execution
     }
 
     private static int getPartitionMaxParallelism(IntermediateResultPartition partition) {
-        return partition
-                .getIntermediateResult()
-                .getConsumerExecutionJobVertex()
-                .getMaxParallelism();
+        return partition.getIntermediateResult().getConsumersMaxParallelism();
     }
 
     /**
@@ -718,31 +716,40 @@ public class Execution
     }
 
     private void updatePartitionConsumers(final IntermediateResultPartition partition) {
-        final Optional<ConsumerVertexGroup> consumerVertexGroup =
-                partition.getConsumerVertexGroupOptional();
-        if (!consumerVertexGroup.isPresent()) {
+        final List<ConsumerVertexGroup> consumerVertexGroups = partition.getConsumerVertexGroups();
+        if (consumerVertexGroups.isEmpty()) {
             return;
         }
-        for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
-            final ExecutionVertex consumerVertex =
-                    vertex.getExecutionGraphAccessor().getExecutionVertexOrThrow(consumerVertexId);
-            final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
-            final ExecutionState consumerState = consumer.getState();
-
-            // ----------------------------------------------------------------
-            // Consumer is recovering or running => send update message now
-            // Consumer is deploying => cache the partition info which would be
-            // sent after switching to running
-            // ----------------------------------------------------------------
-            if (consumerState == DEPLOYING
-                    || consumerState == RUNNING
-                    || consumerState == INITIALIZING) {
-                final PartitionInfo partitionInfo = createPartitionInfo(partition);
-
-                if (consumerState == DEPLOYING) {
-                    consumerVertex.cachePartitionInfo(partitionInfo);
-                } else {
-                    consumer.sendUpdatePartitionInfoRpcCall(Collections.singleton(partitionInfo));
+        final Set<ExecutionVertexID> updatedVertices = new HashSet<>();
+        for (ConsumerVertexGroup consumerVertexGroup : consumerVertexGroups) {
+            for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
+                if (updatedVertices.contains(consumerVertexId)) {
+                    continue;
+                }
+
+                final ExecutionVertex consumerVertex =
+                        vertex.getExecutionGraphAccessor()
+                                .getExecutionVertexOrThrow(consumerVertexId);
+                final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
+                final ExecutionState consumerState = consumer.getState();
+
+                // ----------------------------------------------------------------
+                // Consumer is recovering or running => send update message now
+                // Consumer is deploying => cache the partition info which would be
+                // sent after switching to running
+                // ----------------------------------------------------------------
+                if (consumerState == DEPLOYING
+                        || consumerState == RUNNING
+                        || consumerState == INITIALIZING) {
+                    final PartitionInfo partitionInfo = createPartitionInfo(partition);
+                    updatedVertices.add(consumerVertexId);
+
+                    if (consumerState == DEPLOYING) {
+                        consumerVertex.cachePartitionInfo(partitionInfo);
+                    } else {
+                        consumer.sendUpdatePartitionInfoRpcCall(
+                                Collections.singleton(partitionInfo));
+                    }
                 }
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
index 4b666b5742b..fd8be042f6f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
@@ -32,12 +32,15 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 public class IntermediateResult {
 
@@ -68,6 +71,9 @@ public class IntermediateResult {
     private final Map<ConsumedPartitionGroup, MaybeOffloaded<ShuffleDescriptor[]>>
             shuffleDescriptorCache;
 
+    /** All consumer job vertex ids of this dataset. */
+    private final List<JobVertexID> consumerVertices = new ArrayList<>();
+
     public IntermediateResult(
             IntermediateDataSet intermediateDataSet,
             ExecutionJobVertex producer,
@@ -95,6 +101,10 @@ public class IntermediateResult {
         this.resultType = checkNotNull(resultType);
 
         this.shuffleDescriptorCache = new HashMap<>();
+
+        intermediateDataSet
+                .getConsumers()
+                .forEach(jobEdge -> consumerVertices.add(jobEdge.getTarget().getID()));
     }
 
     public void setPartition(int partitionNumber, IntermediateResultPartition partition) {
@@ -124,6 +134,10 @@ public class IntermediateResult {
         return partitions;
     }
 
+    public List<JobVertexID> getConsumerVertices() {
+        return consumerVertices;
+    }
+
     /**
      * Returns the partition with the given ID.
      *
@@ -162,20 +176,60 @@ public class IntermediateResult {
         return numParallelProducers;
     }
 
-    ExecutionJobVertex getConsumerExecutionJobVertex() {
-        final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
-        final JobVertexID consumerJobVertexId = consumer.getTarget().getID();
-        return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId));
+    /**
+     * Currently, this method is only used to compute the maximum number of consumers. For dynamic
+     * graph, it should be called before adaptively deciding the downstream consumer parallelism.
+     */
+    int getConsumersParallelism() {
+        List<JobEdge> consumers = intermediateDataSet.getConsumers();
+        checkState(!consumers.isEmpty());
+
+        InternalExecutionGraphAccessor graph = getProducer().getGraph();
+        int consumersParallelism =
+                graph.getJobVertex(consumers.get(0).getTarget().getID()).getParallelism();
+        if (consumers.size() == 1) {
+            return consumersParallelism;
+        }
+
+        // sanity check, all consumer vertices must have the same parallelism:
+        // 1. for vertices that are not assigned a parallelism initially (for example, dynamic
+        // graph), the parallelisms will all be -1 (parallelism not decided yet)
+        // 2. for vertices that are initially assigned a parallelism, the parallelisms must be the
+        // same, which is guaranteed at compilation phase
+        for (JobVertexID jobVertexID : consumerVertices) {
+            checkState(
+                    consumersParallelism == graph.getJobVertex(jobVertexID).getParallelism(),
+                    "Consumers must have the same parallelism.");
+        }
+        return consumersParallelism;
+    }
+
+    int getConsumersMaxParallelism() {
+        List<JobEdge> consumers = intermediateDataSet.getConsumers();
+        checkState(!consumers.isEmpty());
+
+        InternalExecutionGraphAccessor graph = getProducer().getGraph();
+        int consumersMaxParallelism =
+                graph.getJobVertex(consumers.get(0).getTarget().getID()).getMaxParallelism();
+        if (consumers.size() == 1) {
+            return consumersMaxParallelism;
+        }
+
+        // sanity check, all consumer vertices must have the same max parallelism
+        for (JobVertexID jobVertexID : consumerVertices) {
+            checkState(
+                    consumersMaxParallelism == graph.getJobVertex(jobVertexID).getMaxParallelism(),
+                    "Consumers must have the same max parallelism.");
+        }
+        return consumersMaxParallelism;
     }
 
     public DistributionPattern getConsumingDistributionPattern() {
-        final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
-        return consumer.getDistributionPattern();
+        return intermediateDataSet.getDistributionPattern();
     }
 
     public boolean isBroadcast() {
-        final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
-        return consumer.isBroadcast();
+        return intermediateDataSet.isBroadcast();
     }
 
     public int getConnectionIndex() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
index 5aef1b07471..9b9c176a3d9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
@@ -21,11 +21,13 @@ package org.apache.flink.runtime.executiongraph;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 
+import java.util.HashSet;
 import java.util.List;
-import java.util.Optional;
+import java.util.Set;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -47,6 +49,12 @@ public class IntermediateResultPartition {
     /** Whether this partition has produced some data. */
     private boolean hasDataProduced = false;
 
+    /**
+     * Releasable {@link ConsumedPartitionGroup}s for this result partition. This result partition
+     * can be released if all {@link ConsumedPartitionGroup}s are releasable.
+     */
+    private final Set<ConsumedPartitionGroup> releasablePartitionGroups = new HashSet<>();
+
     public IntermediateResultPartition(
             IntermediateResult totalResult,
             ExecutionVertex producer,
@@ -58,6 +66,25 @@ public class IntermediateResultPartition {
         this.edgeManager = edgeManager;
     }
 
+    public void markPartitionGroupReleasable(ConsumedPartitionGroup partitionGroup) {
+        releasablePartitionGroups.add(partitionGroup);
+    }
+
+    public boolean canBeReleased() {
+        if (releasablePartitionGroups.size()
+                != edgeManager.getNumberOfConsumedPartitionGroupsById(partitionId)) {
+            return false;
+        }
+        for (JobVertexID jobVertexId : totalResult.getConsumerVertices()) {
+            // for dynamic graph, if any consumer vertex is still not initialized, this result
+            // partition can not be released
+            if (!producer.getExecutionGraphAccessor().getJobVertex(jobVertexId).isInitialized()) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     public ExecutionVertex getProducer() {
         return producer;
     }
@@ -78,15 +105,8 @@ public class IntermediateResultPartition {
         return totalResult.getResultType();
     }
 
-    public ConsumerVertexGroup getConsumerVertexGroup() {
-        Optional<ConsumerVertexGroup> consumerVertexGroup = getConsumerVertexGroupOptional();
-        checkState(consumerVertexGroup.isPresent());
-        return consumerVertexGroup.get();
-    }
-
-    public Optional<ConsumerVertexGroup> getConsumerVertexGroupOptional() {
-        return Optional.ofNullable(
-                getEdgeManager().getConsumerVertexGroupForPartition(partitionId));
+    public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+        return getEdgeManager().getConsumerVertexGroupsForPartition(partitionId);
     }
 
     public List<ConsumedPartitionGroup> getConsumedPartitionGroups() {
@@ -106,12 +126,13 @@ public class IntermediateResultPartition {
 
     private int computeNumberOfSubpartitions() {
         if (!getProducer().getExecutionGraphAccessor().isDynamic()) {
-            ConsumerVertexGroup consumerVertexGroup = getConsumerVertexGroup();
-            checkState(consumerVertexGroup.size() > 0);
+            List<ConsumerVertexGroup> consumerVertexGroups = getConsumerVertexGroups();
+            checkState(!consumerVertexGroups.isEmpty());
 
             // The produced data is partitioned among a number of subpartitions, one for each
-            // consuming sub task.
-            return consumerVertexGroup.size();
+            // consuming sub task. All vertex groups must have the same number of consumers
+            // for non-dynamic graph.
+            return consumerVertexGroups.get(0).size();
         } else {
             if (totalResult.isBroadcast()) {
                 // for dynamic graph and broadcast result, we only produced one subpartition,
@@ -124,18 +145,16 @@ public class IntermediateResultPartition {
     }
 
     private int computeNumberOfMaxPossiblePartitionConsumers() {
-        final ExecutionJobVertex consumerJobVertex =
-                getIntermediateResult().getConsumerExecutionJobVertex();
         final DistributionPattern distributionPattern =
                 getIntermediateResult().getConsumingDistributionPattern();
 
         // decide the max possible consumer job vertex parallelism
-        int maxConsumerJobVertexParallelism = consumerJobVertex.getParallelism();
+        int maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersParallelism();
         if (maxConsumerJobVertexParallelism <= 0) {
+            maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersMaxParallelism();
             checkState(
-                    consumerJobVertex.getMaxParallelism() > 0,
+                    maxConsumerJobVertexParallelism > 0,
                     "Neither the parallelism nor the max parallelism of a job vertex is set");
-            maxConsumerJobVertexParallelism = consumerJobVertex.getMaxParallelism();
         }
 
         // compute number of subpartitions according to the distribution pattern
@@ -163,6 +182,7 @@ public class IntermediateResultPartition {
                 consumedPartitionGroup.partitionUnfinished();
             }
         }
+        releasablePartitionGroups.clear();
         hasDataProduced = false;
         for (ConsumedPartitionGroup consumedPartitionGroup : getConsumedPartitionGroups()) {
             totalResult.clearCachedInformationForPartitionGroup(consumedPartitionGroup);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
index 6e7181ffd03..39d7fe72547 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
@@ -230,12 +230,12 @@ public class RestartPipelinedRegionFailoverStrategy implements FailoverStrategy
 
         for (SchedulingExecutionVertex vertex : regionToRestart.getVertices()) {
             for (SchedulingResultPartition producedPartition : vertex.getProducedResults()) {
-                final Optional<ConsumerVertexGroup> consumerVertexGroup =
-                        producedPartition.getConsumerVertexGroup();
-                if (consumerVertexGroup.isPresent()
-                        && !visitedConsumerVertexGroups.contains(consumerVertexGroup.get())) {
-                    visitedConsumerVertexGroups.add(consumerVertexGroup.get());
-                    consumerVertexGroupsToVisit.add(consumerVertexGroup.get());
+                for (ConsumerVertexGroup consumerVertexGroup :
+                        producedPartition.getConsumerVertexGroups()) {
+                    if (!visitedConsumerVertexGroups.contains(consumerVertexGroup)) {
+                        visitedConsumerVertexGroups.add(consumerVertexGroup);
+                        consumerVertexGroupsToVisit.add(consumerVertexGroup);
+                    }
                 }
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
index de353be7ba8..6428284cd49 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
@@ -32,7 +32,6 @@ import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
 
@@ -116,23 +115,20 @@ public final class SchedulingPipelinedRegionComputeUtil {
                     if (producedResult.getResultType().mustBePipelinedConsumed()) {
                         continue;
                     }
-                    final Optional<ConsumerVertexGroup> consumerVertexGroup =
-                            producedResult.getConsumerVertexGroup();
-                    if (!consumerVertexGroup.isPresent()) {
-                        continue;
-                    }
-
-                    for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
-                        SchedulingExecutionVertex consumerVertex =
-                                executionVertexRetriever.apply(consumerVertexId);
-                        // Skip the ConsumerVertexGroup if its vertices are outside current
-                        // regions and cannot be merged
-                        if (!vertexToRegion.containsKey(consumerVertex)) {
-                            break;
-                        }
-                        if (!currentRegion.contains(consumerVertex)) {
-                            currentRegionOutEdges.add(
-                                    regionIndices.get(vertexToRegion.get(consumerVertex)));
+                    for (ConsumerVertexGroup consumerVertexGroup :
+                            producedResult.getConsumerVertexGroups()) {
+                        for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
+                            SchedulingExecutionVertex consumerVertex =
+                                    executionVertexRetriever.apply(consumerVertexId);
+                            // Skip the ConsumerVertexGroup if its vertices are outside current
+                            // regions and cannot be merged
+                            if (!vertexToRegion.containsKey(consumerVertex)) {
+                                break;
+                            }
+                            if (!currentRegion.contains(consumerVertex)) {
+                                currentRegionOutEdges.add(
+                                        regionIndices.get(vertexToRegion.get(consumerVertex)));
+                            }
                         }
                     }
                 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
index aa45dd4b56f..2d55cf4c639 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
@@ -56,6 +56,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
@@ -93,7 +94,7 @@ public class NettyShuffleEnvironment
 
     private final FileChannelManager fileChannelManager;
 
-    private final Map<InputGateID, SingleInputGate> inputGatesById;
+    private final Map<InputGateID, Set<SingleInputGate>> inputGatesById;
 
     private final ResultPartitionFactory resultPartitionFactory;
 
@@ -169,7 +170,7 @@ public class NettyShuffleEnvironment
     }
 
     @VisibleForTesting
-    public Optional<InputGate> getInputGate(InputGateID id) {
+    public Optional<Collection<SingleInputGate>> getInputGate(InputGateID id) {
         return Optional.ofNullable(inputGatesById.get(id));
     }
 
@@ -260,8 +261,24 @@ public class NettyShuffleEnvironment
                 InputGateID id =
                         new InputGateID(
                                 igdd.getConsumedResultId(), ownerContext.getExecutionAttemptID());
-                inputGatesById.put(id, inputGate);
-                inputGate.getCloseFuture().thenRun(() -> inputGatesById.remove(id));
+                Set<SingleInputGate> inputGateSet =
+                        inputGatesById.computeIfAbsent(
+                                id, ignored -> ConcurrentHashMap.newKeySet());
+                inputGateSet.add(inputGate);
+                inputGatesById.put(id, inputGateSet);
+                inputGate
+                        .getCloseFuture()
+                        .thenRun(
+                                () ->
+                                        inputGatesById.computeIfPresent(
+                                                id,
+                                                (key, value) -> {
+                                                    value.remove(inputGate);
+                                                    if (value.isEmpty()) {
+                                                        return null;
+                                                    }
+                                                    return value;
+                                                }));
                 inputGates[gateIndex] = inputGate;
             }
 
@@ -297,17 +314,20 @@ public class NettyShuffleEnvironment
         IntermediateDataSetID intermediateResultPartitionID =
                 partitionInfo.getIntermediateDataSetID();
         InputGateID id = new InputGateID(intermediateResultPartitionID, consumerID);
-        SingleInputGate inputGate = inputGatesById.get(id);
-        if (inputGate == null) {
+        Set<SingleInputGate> inputGates = inputGatesById.get(id);
+        if (inputGates == null || inputGates.isEmpty()) {
             return false;
         }
+
         ShuffleDescriptor shuffleDescriptor = partitionInfo.getShuffleDescriptor();
         checkArgument(
                 shuffleDescriptor instanceof NettyShuffleDescriptor,
                 "Tried to update unknown channel with unknown ShuffleDescriptor %s.",
                 shuffleDescriptor.getClass().getName());
-        inputGate.updateInputChannel(
-                taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor);
+        for (SingleInputGate inputGate : inputGates) {
+            inputGate.updateInputChannel(
+                    taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor);
+        }
         return true;
     }
 
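The input-gate registry change above replaces the one-gate-per-id mapping with a set of gates per InputGateID. Stripped of Flink types, the register/close pattern boils down to the following sketch (plain String keys and values, java.util.concurrent.ConcurrentHashMap; illustrative only, not part of the commit):

    Map<String, Set<String>> gatesById = new ConcurrentHashMap<>();

    // register: create the set lazily and add the gate
    gatesById.computeIfAbsent("inputGateId", ignored -> ConcurrentHashMap.newKeySet())
            .add("gate-1");

    // on close: remove the gate and drop the whole mapping once the set is empty
    gatesById.computeIfPresent("inputGateId", (key, gates) -> {
        gates.remove("gate-1");
        return gates.isEmpty() ? null : gates;
    });
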
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
index d6b3abc3b52..2aad2613996 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
@@ -20,7 +20,8 @@ package org.apache.flink.runtime.jobgraph;
 
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 
-import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
@@ -39,11 +40,16 @@ public class IntermediateDataSet implements java.io.Serializable {
 
     private final JobVertex producer; // the operation that produced this data set
 
-    @Nullable private JobEdge consumer;
+    // All consumers must have the same partitioner and parallelism
+    private final List<JobEdge> consumers = new ArrayList<>();
 
     // The type of partition to use at runtime
     private final ResultPartitionType resultType;
 
+    private DistributionPattern distributionPattern;
+
+    private boolean isBroadcast;
+
     // --------------------------------------------------------------------------------------------
 
     public IntermediateDataSet(
@@ -63,9 +69,16 @@ public class IntermediateDataSet implements java.io.Serializable {
         return producer;
     }
 
-    @Nullable
-    public JobEdge getConsumer() {
-        return consumer;
+    public List<JobEdge> getConsumers() {
+        return this.consumers;
+    }
+
+    public boolean isBroadcast() {
+        return isBroadcast;
+    }
+
+    public DistributionPattern getDistributionPattern() {
+        return distributionPattern;
     }
 
     public ResultPartitionType getResultType() {
@@ -75,10 +88,19 @@ public class IntermediateDataSet implements java.io.Serializable {
     // --------------------------------------------------------------------------------------------
 
     public void addConsumer(JobEdge edge) {
-        checkState(
-                this.consumer == null,
-                "Currently one IntermediateDataSet can have at most one consumer.");
-        this.consumer = edge;
+        // sanity check
+        checkState(id.equals(edge.getSourceId()), "Incompatible dataset id.");
+
+        if (consumers.isEmpty()) {
+            distributionPattern = edge.getDistributionPattern();
+            isBroadcast = edge.isBroadcast();
+        } else {
+            checkState(
+                    distributionPattern == edge.getDistributionPattern(),
+                    "Incompatible distribution pattern.");
+            checkState(isBroadcast == edge.isBroadcast(), "Incompatible broadcast type.");
+        }
+        consumers.add(edge);
     }
 
     // --------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
index 9772ff4dbb1..4649303c144 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
@@ -44,10 +44,7 @@ public class JobEdge implements java.io.Serializable {
     private SubtaskStateMapper upstreamSubtaskStateMapper = SubtaskStateMapper.ROUND_ROBIN;
 
     /** The data set at the source of the edge, may be null if the edge is not yet connected. */
-    private IntermediateDataSet source;
-
-    /** The id of the source intermediate data set. */
-    private IntermediateDataSetID sourceId;
+    private final IntermediateDataSet source;
 
     /**
      * Optional name for the data shipping strategy (forward, partition hash, rebalance, ...), to be
@@ -55,7 +52,7 @@ public class JobEdge implements java.io.Serializable {
      */
     private String shipStrategyName;
 
-    private boolean isBroadcast;
+    private final boolean isBroadcast;
 
     private boolean isForward;
 
@@ -74,36 +71,20 @@ public class JobEdge implements java.io.Serializable {
      * @param source The data set that is at the source of this edge.
      * @param target The operation that is at the target of this edge.
      * @param distributionPattern The pattern that defines how the connection behaves in parallel.
+     * @param isBroadcast Whether the source broadcasts data to the target.
      */
     public JobEdge(
-            IntermediateDataSet source, JobVertex target, DistributionPattern distributionPattern) {
+            IntermediateDataSet source,
+            JobVertex target,
+            DistributionPattern distributionPattern,
+            boolean isBroadcast) {
         if (source == null || target == null || distributionPattern == null) {
             throw new NullPointerException();
         }
         this.target = target;
         this.distributionPattern = distributionPattern;
         this.source = source;
-        this.sourceId = source.getId();
-    }
-
-    /**
-     * Constructs a new job edge that refers to an intermediate result via the Id, rather than
-     * directly through the intermediate data set structure.
-     *
-     * @param sourceId The id of the data set that is at the source of this edge.
-     * @param target The operation that is at the target of this edge.
-     * @param distributionPattern The pattern that defines how the connection behaves in parallel.
-     */
-    public JobEdge(
-            IntermediateDataSetID sourceId,
-            JobVertex target,
-            DistributionPattern distributionPattern) {
-        if (sourceId == null || target == null || distributionPattern == null) {
-            throw new NullPointerException();
-        }
-        this.target = target;
-        this.distributionPattern = distributionPattern;
-        this.sourceId = sourceId;
+        this.isBroadcast = isBroadcast;
     }
 
     /**
@@ -140,11 +121,7 @@ public class JobEdge implements java.io.Serializable {
      * @return The ID of the consumed data set.
      */
     public IntermediateDataSetID getSourceId() {
-        return sourceId;
-    }
-
-    public boolean isIdReference() {
-        return this.source == null;
+        return source.getId();
     }
 
     // --------------------------------------------------------------------------------------------
@@ -173,11 +150,6 @@ public class JobEdge implements java.io.Serializable {
         return isBroadcast;
     }
 
-    /** Sets whether the edge is broadcast edge. */
-    public void setBroadcast(boolean broadcast) {
-        isBroadcast = broadcast;
-    }
-
     /** Gets whether the edge is forward edge. */
     public boolean isForward() {
         return isForward;
@@ -268,6 +240,6 @@ public class JobEdge implements java.io.Serializable {
 
     @Override
     public String toString() {
-        return String.format("%s --> %s [%s]", sourceId, target, distributionPattern.name());
+        return String.format("%s --> %s [%s]", source.getId(), target, distributionPattern.name());
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
index e8baef4d8e7..bb821adda7b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
@@ -458,10 +458,9 @@ public class JobGraph implements Serializable {
     private void addNodesThatHaveNoNewPredecessors(
             JobVertex start, List<JobVertex> target, Set<JobVertex> remaining) {
 
-        // forward traverse over all produced data sets
+        // forward traverse over all produced data sets and all their consumers
         for (IntermediateDataSet dataSet : start.getProducedDataSets()) {
-            JobEdge edge = dataSet.getConsumer();
-            if (edge != null) {
+            for (JobEdge edge : dataSet.getConsumers()) {
 
                 // a vertex can be added, if it has no predecessors that are still in the
                 // 'remaining' set
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
index 36d1a1e5487..717bd6d4376 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
@@ -36,7 +36,9 @@ import javax.annotation.Nullable;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -72,8 +74,8 @@ public class JobVertex implements java.io.Serializable {
      */
     private final List<OperatorIDPair> operatorIDs;
 
-    /** List of produced data sets, one per writer. */
-    private final ArrayList<IntermediateDataSet> results = new ArrayList<>();
+    /** Produced data sets, one per writer. */
+    private final Map<IntermediateDataSetID, IntermediateDataSet> results = new LinkedHashMap<>();
 
     /** List of edges with incoming data. One per Reader. */
     private final ArrayList<JobEdge> inputs = new ArrayList<>();
@@ -374,7 +376,7 @@ public class JobVertex implements java.io.Serializable {
     }
 
     public List<IntermediateDataSet> getProducedDataSets() {
-        return this.results;
+        return new ArrayList<>(results.values());
     }
 
     public List<JobEdge> getInputs() {
@@ -481,30 +483,37 @@ public class JobVertex implements java.io.Serializable {
     }
 
     // --------------------------------------------------------------------------------------------
-    public IntermediateDataSet createAndAddResultDataSet(
+    public IntermediateDataSet getOrCreateResultDataSet(
             IntermediateDataSetID id, ResultPartitionType partitionType) {
-
-        IntermediateDataSet result = new IntermediateDataSet(id, partitionType, this);
-        this.results.add(result);
-        return result;
+        return this.results.computeIfAbsent(
+                id, key -> new IntermediateDataSet(id, partitionType, this));
     }
 
     public JobEdge connectNewDataSetAsInput(
             JobVertex input, DistributionPattern distPattern, ResultPartitionType partitionType) {
+        return connectNewDataSetAsInput(input, distPattern, partitionType, false);
+    }
+
+    public JobEdge connectNewDataSetAsInput(
+            JobVertex input,
+            DistributionPattern distPattern,
+            ResultPartitionType partitionType,
+            boolean isBroadcast) {
         return connectNewDataSetAsInput(
-                input, distPattern, partitionType, new IntermediateDataSetID());
+                input, distPattern, partitionType, new IntermediateDataSetID(), isBroadcast);
     }
 
     public JobEdge connectNewDataSetAsInput(
             JobVertex input,
             DistributionPattern distPattern,
             ResultPartitionType partitionType,
-            IntermediateDataSetID intermediateDataSetId) {
+            IntermediateDataSetID intermediateDataSetId,
+            boolean isBroadcast) {
 
         IntermediateDataSet dataSet =
-                input.createAndAddResultDataSet(intermediateDataSetId, partitionType);
+                input.getOrCreateResultDataSet(intermediateDataSetId, partitionType);
 
-        JobEdge edge = new JobEdge(dataSet, this, distPattern);
+        JobEdge edge = new JobEdge(dataSet, this, distPattern, isBroadcast);
         this.inputs.add(edge);
         dataSet.addConsumer(edge);
         return edge;
@@ -525,13 +534,7 @@ public class JobVertex implements java.io.Serializable {
     }
 
     public boolean hasNoConnectedInputs() {
-        for (JobEdge edge : inputs) {
-            if (!edge.isIdReference()) {
-                return false;
-            }
-        }
-
-        return true;
+        return inputs.isEmpty();
     }
 
     public void markContainsSources() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index 0fac885dd3f..1506152fa38 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -45,7 +45,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -143,15 +142,22 @@ public class SsgNetworkMemoryCalculationUtils {
         Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
         List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets();
 
-        for (int i = 0; i < producedDataSets.size(); i++) {
-            IntermediateDataSet producedDataSet = producedDataSets.get(i);
-            JobEdge outputEdge = checkNotNull(producedDataSet.getConsumer());
-            ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
-            int maxNum =
-                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
-                            ejv.getParallelism(),
-                            consumerJobVertex.getParallelism(),
-                            outputEdge.getDistributionPattern());
+        checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph.");
+        for (IntermediateDataSet producedDataSet : producedDataSets) {
+            int maxNum = 0;
+            List<JobEdge> outputEdges = producedDataSet.getConsumers();
+
+            if (!outputEdges.isEmpty()) {
+                // for non-dynamic graph, the consumer vertices' parallelisms and distribution
+                // patterns must be the same
+                JobEdge outputEdge = outputEdges.get(0);
+                ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
+                maxNum =
+                        EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
+                                ejv.getParallelism(),
+                                consumerJobVertex.getParallelism(),
+                                outputEdge.getDistributionPattern());
+            }
             ret.put(producedDataSet.getId(), maxNum);
         }
 
@@ -177,7 +183,9 @@ public class SsgNetworkMemoryCalculationUtils {
                         ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst()));
                 SubpartitionIndexRange subpartitionIndexRange =
                         TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
-                                resultPartition, vertex.getParallelSubtaskIndex());
+                                partitionGroup.getNumConsumers(),
+                                resultPartition,
+                                vertex.getParallelSubtaskIndex());
 
                 ret.merge(
                         partitionGroup.getIntermediateDataSetID(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
index 486ae28b264..f2fd573e613 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
@@ -262,7 +262,7 @@ public class DefaultExecutionTopology implements SchedulingTopology {
             List<DefaultResultPartition> producedPartitions =
                     generateProducedSchedulingResultPartition(
                             vertex.getProducedPartitions(),
-                            edgeManager::getConsumerVertexGroupForPartition);
+                            edgeManager::getConsumerVertexGroupsForPartition);
 
             producedPartitions.forEach(
                     partition -> resultPartitionsById.put(partition.getId(), partition));
@@ -285,8 +285,8 @@ public class DefaultExecutionTopology implements SchedulingTopology {
     private static List<DefaultResultPartition> generateProducedSchedulingResultPartition(
             Map<IntermediateResultPartitionID, IntermediateResultPartition>
                     producedIntermediatePartitions,
-            Function<IntermediateResultPartitionID, ConsumerVertexGroup>
-                    partitionConsumerVertexGroupRetriever) {
+            Function<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
+                    partitionConsumerVertexGroupsRetriever) {
 
         List<DefaultResultPartition> producedSchedulingPartitions =
                 new ArrayList<>(producedIntermediatePartitions.size());
@@ -305,8 +305,8 @@ public class DefaultExecutionTopology implements SchedulingTopology {
                                                                 ? ResultPartitionState.CONSUMABLE
                                                                 : ResultPartitionState.CREATED,
                                                 () ->
-                                                        partitionConsumerVertexGroupRetriever.apply(
-                                                                irp.getPartitionId()),
+                                                        partitionConsumerVertexGroupsRetriever
+                                                                .apply(irp.getPartitionId()),
                                                 irp::getConsumedPartitionGroups)));
 
         return producedSchedulingPartitions;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
index b46c54bd58e..2ae54af4304 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
@@ -27,7 +27,6 @@ import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition;
 
 import java.util.List;
-import java.util.Optional;
 import java.util.function.Supplier;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -45,7 +44,7 @@ class DefaultResultPartition implements SchedulingResultPartition {
 
     private DefaultExecutionVertex producer;
 
-    private final Supplier<ConsumerVertexGroup> consumerVertexGroupSupplier;
+    private final Supplier<List<ConsumerVertexGroup>> consumerVertexGroupsSupplier;
 
     private final Supplier<List<ConsumedPartitionGroup>> consumerPartitionGroupSupplier;
 
@@ -54,13 +53,13 @@ class DefaultResultPartition implements SchedulingResultPartition {
             IntermediateDataSetID intermediateDataSetId,
             ResultPartitionType partitionType,
             Supplier<ResultPartitionState> resultPartitionStateSupplier,
-            Supplier<ConsumerVertexGroup> consumerVertexGroupSupplier,
+            Supplier<List<ConsumerVertexGroup>> consumerVertexGroupsSupplier,
             Supplier<List<ConsumedPartitionGroup>> consumerPartitionGroupSupplier) {
         this.resultPartitionId = checkNotNull(partitionId);
         this.intermediateDataSetId = checkNotNull(intermediateDataSetId);
         this.partitionType = checkNotNull(partitionType);
         this.resultPartitionStateSupplier = checkNotNull(resultPartitionStateSupplier);
-        this.consumerVertexGroupSupplier = checkNotNull(consumerVertexGroupSupplier);
+        this.consumerVertexGroupsSupplier = checkNotNull(consumerVertexGroupsSupplier);
         this.consumerPartitionGroupSupplier = checkNotNull(consumerPartitionGroupSupplier);
     }
 
@@ -90,8 +89,8 @@ class DefaultResultPartition implements SchedulingResultPartition {
     }
 
     @Override
-    public Optional<ConsumerVertexGroup> getConsumerVertexGroup() {
-        return Optional.ofNullable(consumerVertexGroupSupplier.get());
+    public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+        return checkNotNull(consumerVertexGroupsSupplier.get());
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
index 6e4672c48e9..2f5be93ca51 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
@@ -42,12 +42,17 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
 
     private final ResultPartitionType resultPartitionType;
 
+    /** Number of consumer tasks in the corresponding {@link ConsumerVertexGroup}. */
+    private final int numConsumers;
+
     private ConsumedPartitionGroup(
+            int numConsumers,
             List<IntermediateResultPartitionID> resultPartitions,
             ResultPartitionType resultPartitionType) {
         checkArgument(
                 resultPartitions.size() > 0,
                 "The size of result partitions in the ConsumedPartitionGroup should be larger than 0.");
+        this.numConsumers = numConsumers;
         this.intermediateDataSetID = resultPartitions.get(0).getIntermediateDataSetID();
         this.resultPartitionType = Preconditions.checkNotNull(resultPartitionType);
 
@@ -63,16 +68,18 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
     }
 
     public static ConsumedPartitionGroup fromMultiplePartitions(
+            int numConsumers,
             List<IntermediateResultPartitionID> resultPartitions,
             ResultPartitionType resultPartitionType) {
-        return new ConsumedPartitionGroup(resultPartitions, resultPartitionType);
+        return new ConsumedPartitionGroup(numConsumers, resultPartitions, resultPartitionType);
     }
 
     public static ConsumedPartitionGroup fromSinglePartition(
+            int numConsumers,
             IntermediateResultPartitionID resultPartition,
             ResultPartitionType resultPartitionType) {
         return new ConsumedPartitionGroup(
-                Collections.singletonList(resultPartition), resultPartitionType);
+                numConsumers, Collections.singletonList(resultPartition), resultPartitionType);
     }
 
     @Override
@@ -88,6 +95,14 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
         return resultPartitions.isEmpty();
     }
 
+    /**
+     * In dynamic graph cases, different ConsumedPartitionGroups can have different numbers of
+     * consumers even if they contain the same IntermediateResultPartitions.
+     */
+    public int getNumConsumers() {
+        return numConsumers;
+    }
+
     public IntermediateResultPartitionID getFirst() {
         return iterator().next();
     }
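
The consumer count is stored per group rather than derived from the contained partitions because, in dynamic
graphs, two groups over the same partitions can belong to consumer vertices of different parallelism. A small
sketch of the extended factory methods (the partition id is just a placeholder; the no-argument
IntermediateResultPartitionID constructor is assumed from the test utilities):

    IntermediateResultPartitionID partitionId = new IntermediateResultPartitionID();
    ConsumedPartitionGroup groupA =
            ConsumedPartitionGroup.fromSinglePartition(4, partitionId, ResultPartitionType.BLOCKING);
    ConsumedPartitionGroup groupB =
            ConsumedPartitionGroup.fromSinglePartition(2, partitionId, ResultPartitionType.BLOCKING);

    // Same partition, different consumer counts:
    // groupA.getNumConsumers() == 4, groupB.getNumConsumers() == 2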
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
index b52150c96cd..86cafebac6b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
@@ -24,7 +24,6 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.topology.Result;
 
 import java.util.List;
-import java.util.Optional;
 
 /** Representation of {@link IntermediateResultPartition}. */
 public interface SchedulingResultPartition
@@ -49,11 +48,11 @@ public interface SchedulingResultPartition
     ResultPartitionState getState();
 
     /**
-     * Gets the {@link ConsumerVertexGroup}.
+     * Gets the {@link ConsumerVertexGroup}s.
      *
-     * @return {@link ConsumerVertexGroup} if consumers exists, otherwise {@link Optional#empty()}.
+     * @return list of {@link ConsumerVertexGroup}s
      */
-    Optional<ConsumerVertexGroup> getConsumerVertexGroup();
+    List<ConsumerVertexGroup> getConsumerVertexGroups();
 
     /**
      * Gets the {@link ConsumedPartitionGroup}s this partition belongs to.
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
index 4f7d01e4e8f..4d217a9eb97 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
@@ -24,12 +24,12 @@ import org.apache.flink.runtime.scheduler.SchedulerOperations;
 import org.apache.flink.runtime.scheduler.SchedulingTopologyListener;
 import org.apache.flink.util.IterableUtils;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -85,11 +85,9 @@ public class VertexwiseSchedulingStrategy
 
             Set<ExecutionVertexID> consumerVertices =
                     IterableUtils.toStream(executionVertex.getProducedResults())
-                            .map(SchedulingResultPartition::getConsumerVertexGroup)
-                            .filter(Optional::isPresent)
-                            .flatMap(
-                                    consumerVertexGroup ->
-                                            IterableUtils.toStream(consumerVertexGroup.get()))
+                            .map(SchedulingResultPartition::getConsumerVertexGroups)
+                            .flatMap(Collection::stream)
+                            .flatMap(IterableUtils::toStream)
                             .collect(Collectors.toSet());
 
             maybeScheduleVertices(consumerVertices);
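
The stream above is the scheduling-strategy view of the new accessor; a plain-loop equivalent for any
SchedulingResultPartition named partition looks like this (sketch only, variable names assumed):

    Set<ExecutionVertexID> consumerTasks = new HashSet<>();
    for (ConsumerVertexGroup group : partition.getConsumerVertexGroups()) {
        // one group per consumer job vertex of the shared dataset
        for (ExecutionVertexID consumerTask : group) {
            consumerTasks.add(consumerTask);
        }
    }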
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java
new file mode 100644
index 00000000000..af1c543c509
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.blob.TestingBlobWriter;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
+import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.scheduler.SchedulerBase;
+import org.apache.flink.testutils.TestingUtils;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createSchedulerAndDeploy;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests that blocking result partitions are properly released. */
+class BlockingResultPartitionReleaseTest {
+
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
+
+    private ScheduledExecutorService scheduledExecutorService;
+    private ComponentMainThreadExecutor mainThreadExecutor;
+    private ManuallyTriggeredScheduledExecutorService ioExecutor;
+
+    @BeforeEach
+    void setup() {
+        scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
+        mainThreadExecutor =
+                ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor(
+                        scheduledExecutorService);
+        ioExecutor = new ManuallyTriggeredScheduledExecutorService();
+    }
+
+    @AfterEach
+    void teardown() {
+        if (scheduledExecutorService != null) {
+            scheduledExecutorService.shutdownNow();
+        }
+    }
+
+    @Test
+    void testMultipleConsumersForAdaptiveBatchScheduler() throws Exception {
+        testResultPartitionConsumedByMultiConsumers(true);
+    }
+
+    @Test
+    void testMultipleConsumersForDefaultScheduler() throws Exception {
+        testResultPartitionConsumedByMultiConsumers(false);
+    }
+
+    private void testResultPartitionConsumedByMultiConsumers(boolean isAdaptive) throws Exception {
+        int parallelism = 2;
+        JobID jobId = new JobID();
+        JobVertex producer = ExecutionGraphTestUtils.createNoOpVertex("producer", parallelism);
+        JobVertex consumer1 = ExecutionGraphTestUtils.createNoOpVertex("consumer1", parallelism);
+        JobVertex consumer2 = ExecutionGraphTestUtils.createNoOpVertex("consumer2", parallelism);
+
+        TestingPartitionTracker partitionTracker = new TestingPartitionTracker();
+        SchedulerBase scheduler =
+                createSchedulerAndDeploy(
+                        isAdaptive,
+                        jobId,
+                        producer,
+                        new JobVertex[] {consumer1, consumer2},
+                        DistributionPattern.ALL_TO_ALL,
+                        new TestingBlobWriter(Integer.MAX_VALUE),
+                        mainThreadExecutor,
+                        ioExecutor,
+                        partitionTracker,
+                        EXECUTOR_RESOURCE.getExecutor());
+        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
+
+        assertThat(partitionTracker.releasedPartitions).isEmpty();
+
+        CompletableFuture.runAsync(
+                        () -> finishJobVertex(executionGraph, consumer1.getID()),
+                        mainThreadExecutor)
+                .join();
+        ioExecutor.triggerAll();
+
+        assertThat(partitionTracker.releasedPartitions).isEmpty();
+
+        CompletableFuture.runAsync(
+                        () -> finishJobVertex(executionGraph, consumer2.getID()),
+                        mainThreadExecutor)
+                .join();
+        ioExecutor.triggerAll();
+
+        assertThat(partitionTracker.releasedPartitions.size()).isEqualTo(parallelism);
+        for (int i = 0; i < parallelism; ++i) {
+            ExecutionJobVertex ejv = checkNotNull(executionGraph.getJobVertex(producer.getID()));
+            assertThat(
+                            partitionTracker.releasedPartitions.stream()
+                                    .map(ResultPartitionID::getPartitionId))
+                    .containsExactlyInAnyOrder(
+                            Arrays.stream(ejv.getProducedDataSets()[0].getPartitions())
+                                    .map(IntermediateResultPartition::getPartitionId)
+                                    .toArray(IntermediateResultPartitionID[]::new));
+        }
+    }
+
+    private static class TestingPartitionTracker extends NoOpJobMasterPartitionTracker {
+
+        private final List<ResultPartitionID> releasedPartitions = new ArrayList<>();
+
+        @Override
+        public void stopTrackingAndReleasePartitions(
+                Collection<ResultPartitionID> resultPartitionIds, boolean releaseOnShuffleMaster) {
+            releasedPartitions.addAll(checkNotNull(resultPartitionIds));
+        }
+    }
+}
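
The shared blocking partitions become releasable only after every consuming job vertex has finished: finishing
consumer1 alone releases nothing, and finishing consumer2 then hands each of the producer's partitions to the
partition tracker exactly once. The same condition is exercised at the partition level by the new
testReleasePartitionGroups further below, where a partition can be released only once all of its
ConsumedPartitionGroups have been marked releasable, and a reset clears that state again.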
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
index 9245fa8f966..588a91c6fd7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.io.InputSplitSource;
 import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -48,6 +49,7 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
@@ -261,6 +263,45 @@ class DefaultExecutionGraphConstructionTest {
         assertThat(eg.getAllVertices().get(v5.getID()).getSplitAssigner()).isEqualTo(assigner2);
     }
 
+    @Test
+    void testMultiConsumersForOneIntermediateResult() throws Exception {
+        JobVertex v1 = new JobVertex("vertex1");
+        JobVertex v2 = new JobVertex("vertex2");
+        JobVertex v3 = new JobVertex("vertex3");
+
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        v2.connectNewDataSetAsInput(
+                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, dataSetId, false);
+        v3.connectNewDataSetAsInput(
+                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, dataSetId, false);
+
+        List<JobVertex> vertices = new ArrayList<>(Arrays.asList(v1, v2, v3));
+        ExecutionGraph eg = createDefaultExecutionGraph(vertices);
+        eg.attachJobGraph(vertices);
+
+        ExecutionJobVertex ejv1 = checkNotNull(eg.getJobVertex(v1.getID()));
+        assertThat(ejv1.getProducedDataSets()).hasSize(1);
+        assertThat(ejv1.getProducedDataSets()[0].getId()).isEqualTo(dataSetId);
+
+        ExecutionJobVertex ejv2 = checkNotNull(eg.getJobVertex(v2.getID()));
+        assertThat(ejv2.getInputs()).hasSize(1);
+        assertThat(ejv2.getInputs().get(0).getId()).isEqualTo(dataSetId);
+
+        ExecutionJobVertex ejv3 = checkNotNull(eg.getJobVertex(v3.getID()));
+        assertThat(ejv3.getInputs()).hasSize(1);
+        assertThat(ejv3.getInputs().get(0).getId()).isEqualTo(dataSetId);
+
+        List<ConsumedPartitionGroup> partitionGroups1 =
+                ejv2.getTaskVertices()[0].getAllConsumedPartitionGroups();
+        assertThat(partitionGroups1).hasSize(1);
+        assertThat(partitionGroups1.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId);
+
+        List<ConsumedPartitionGroup> partitionGroups2 =
+                ejv3.getTaskVertices()[0].getAllConsumedPartitionGroups();
+        assertThat(partitionGroups2).hasSize(1);
+        assertThat(partitionGroups2.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId);
+    }
+
     @Test
     void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
         JobVertex v1 = new JobVertex("source");
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index e3b603a8790..28b44595a26 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -85,7 +85,7 @@ class EdgeManagerBuildUtilTest {
 
             IntermediateResultPartition partition =
                     ev.getProducedPartitions().values().iterator().next();
-            ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroup();
+            ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroups().get(0);
             int actual = consumerVertexGroup.size();
             if (actual > actualMaxForUpstream) {
                 actualMaxForUpstream = actual;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
index 81165da6193..ad94749e7e8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
@@ -34,9 +34,15 @@ import org.apache.flink.testutils.executor.TestExecutorExtension;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
+import java.util.Arrays;
+import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
 
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link EdgeManager}. */
@@ -50,23 +56,7 @@ class EdgeManagerTest {
     void testGetConsumedPartitionGroup() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
-
-        v1.setParallelism(2);
-        v2.setParallelism(2);
-
-        v1.setInvokableClass(NoOpInvokable.class);
-        v2.setInvokableClass(NoOpInvokable.class);
-
-        v2.connectNewDataSetAsInput(
-                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
-
-        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(v1, v2);
-        SchedulerBase scheduler =
-                SchedulerTestingUtils.createScheduler(
-                        jobGraph,
-                        ComponentMainThreadExecutorServiceAdapter.forMainThread(),
-                        EXECUTOR_RESOURCE.getExecutor());
-        ExecutionGraph eg = scheduler.getExecutionGraph();
+        ExecutionGraph eg = buildExecutionGraph(v1, v2, 2, 2, ALL_TO_ALL);
 
         ConsumedPartitionGroup groupRetrievedByDownstreamVertex =
                 Objects.requireNonNull(eg.getJobVertex(v2.getID()))
@@ -86,9 +76,7 @@ class EdgeManagerTest {
                 .isEqualTo(groupRetrievedByDownstreamVertex);
 
         ConsumedPartitionGroup groupRetrievedByScheduledResultPartition =
-                scheduler
-                        .getExecutionGraph()
-                        .getSchedulingTopology()
+                eg.getSchedulingTopology()
                         .getResultPartition(consumedPartition.getPartitionId())
                         .getConsumedPartitionGroups()
                         .get(0);
@@ -96,4 +84,64 @@ class EdgeManagerTest {
         assertThat(groupRetrievedByScheduledResultPartition)
                 .isEqualTo(groupRetrievedByDownstreamVertex);
     }
+
+    @Test
+    public void testCalculateNumberOfConsumers() throws Exception {
+        testCalculateNumberOfConsumers(5, 2, ALL_TO_ALL, new int[] {2, 2});
+        testCalculateNumberOfConsumers(5, 2, POINTWISE, new int[] {1, 1});
+        testCalculateNumberOfConsumers(2, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
+        testCalculateNumberOfConsumers(2, 5, POINTWISE, new int[] {3, 3, 3, 2, 2});
+        testCalculateNumberOfConsumers(5, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
+        testCalculateNumberOfConsumers(5, 5, POINTWISE, new int[] {1, 1, 1, 1, 1});
+    }
+
+    private void testCalculateNumberOfConsumers(
+            int producerParallelism,
+            int consumerParallelism,
+            DistributionPattern distributionPattern,
+            int[] expectedConsumers)
+            throws Exception {
+        JobVertex producer = new JobVertex("producer");
+        JobVertex consumer = new JobVertex("consumer");
+        ExecutionGraph eg =
+                buildExecutionGraph(
+                        producer,
+                        consumer,
+                        producerParallelism,
+                        consumerParallelism,
+                        distributionPattern);
+        List<ConsumedPartitionGroup> partitionGroups =
+                Arrays.stream(checkNotNull(eg.getJobVertex(consumer.getID())).getTaskVertices())
+                        .flatMap(ev -> ev.getAllConsumedPartitionGroups().stream())
+                        .collect(Collectors.toList());
+        int index = 0;
+        for (ConsumedPartitionGroup partitionGroup : partitionGroups) {
+            assertThat(partitionGroup.getNumConsumers()).isEqualTo(expectedConsumers[index++]);
+        }
+    }
+
+    private ExecutionGraph buildExecutionGraph(
+            JobVertex producer,
+            JobVertex consumer,
+            int producerParallelism,
+            int consumerParallelism,
+            DistributionPattern distributionPattern)
+            throws Exception {
+        producer.setParallelism(producerParallelism);
+        consumer.setParallelism(consumerParallelism);
+
+        producer.setInvokableClass(NoOpInvokable.class);
+        consumer.setInvokableClass(NoOpInvokable.class);
+
+        consumer.connectNewDataSetAsInput(
+                producer, distributionPattern, ResultPartitionType.BLOCKING);
+
+        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(producer, consumer);
+        SchedulerBase scheduler =
+                SchedulerTestingUtils.createScheduler(
+                        jobGraph,
+                        ComponentMainThreadExecutorServiceAdapter.forMainThread(),
+                        EXECUTOR_RESOURCE.getExecutor());
+        return scheduler.getExecutionGraph();
+    }
 }
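
The expected arrays in testCalculateNumberOfConsumers are per consumer subtask: with ALL_TO_ALL every
ConsumedPartitionGroup is read by all consumer subtasks, so each entry equals the consumer parallelism, while
with POINTWISE the producer partitions are split as evenly as possible across consumers. For 2 producers and
5 consumers, partition 0 is shared by 3 consumer subtasks and partition 1 by 2, giving {3, 3, 3, 2, 2}; with
at least as many producers as consumers, each consumer subtask gets its own group, hence the entries of 1.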
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
index ff2baf807e1..a9e7df9470d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
 
@@ -43,6 +44,7 @@ import javax.annotation.Nullable;
 import java.lang.reflect.Field;
 import java.time.Duration;
 import java.util.List;
+import java.util.Objects;
 import java.util.Random;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
@@ -249,6 +251,23 @@ public class ExecutionGraphTestUtils {
         }
     }
 
+    public static void finishJobVertex(ExecutionGraph executionGraph, JobVertexID jobVertexId) {
+        for (ExecutionVertex vertex :
+                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexId))
+                        .getTaskVertices()) {
+            finishExecutionVertex(executionGraph, vertex);
+        }
+    }
+
+    public static void finishExecutionVertex(
+            ExecutionGraph executionGraph, ExecutionVertex executionVertex) {
+        executionGraph.updateState(
+                new TaskExecutionStateTransition(
+                        new TaskExecutionState(
+                                executionVertex.getCurrentExecutionAttempt().getAttemptId(),
+                                ExecutionState.FINISHED)));
+    }
+
     /**
      * Takes all vertices in the given ExecutionGraph and switches their current execution to
      * FINISHED.
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index d847e9ede68..dda16596751 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -170,7 +170,7 @@ class ExecutionJobVertexTest {
             int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception {
         JobVertex jobVertex = new JobVertex("testVertex");
         jobVertex.setInvokableClass(AbstractInvokable.class);
-        jobVertex.createAndAddResultDataSet(
+        jobVertex.getOrCreateResultDataSet(
                 new IntermediateDataSetID(), ResultPartitionType.BLOCKING);
 
         if (maxParallelism > 0) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 5ff5d9f201c..aa906caf674 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
 import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
@@ -151,6 +152,32 @@ public class IntermediateResultPartitionTest {
         assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
     }
 
+    @Test
+    void testReleasePartitionGroups() throws Exception {
+        IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
+
+        IntermediateResultPartition partition1 = result.getPartitions()[0];
+        IntermediateResultPartition partition2 = result.getPartitions()[1];
+        assertThat(partition1.canBeReleased()).isFalse();
+        assertThat(partition2.canBeReleased()).isFalse();
+
+        List<ConsumedPartitionGroup> consumedPartitionGroup1 =
+                partition1.getConsumedPartitionGroups();
+        List<ConsumedPartitionGroup> consumedPartitionGroup2 =
+                partition2.getConsumedPartitionGroups();
+        assertThat(consumedPartitionGroup1).isEqualTo(consumedPartitionGroup2);
+
+        assertThat(consumedPartitionGroup1).hasSize(2);
+        partition1.markPartitionGroupReleasable(consumedPartitionGroup1.get(0));
+        assertThat(partition1.canBeReleased()).isFalse();
+
+        partition1.markPartitionGroupReleasable(consumedPartitionGroup1.get(1));
+        assertThat(partition1.canBeReleased()).isTrue();
+
+        result.resetForNewExecution();
+        assertThat(partition1.canBeReleased()).isFalse();
+    }
+
     @Test
     void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, false, Arrays.asList(7, 7));
@@ -245,11 +272,24 @@ public class IntermediateResultPartitionTest {
             v2.setMaxParallelism(consumerMaxParallelism);
         }
 
-        v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING);
+        final JobVertex v3 = new JobVertex("v3");
+        v3.setInvokableClass(NoOpInvokable.class);
+        if (consumerParallelism > 0) {
+            v3.setParallelism(consumerParallelism);
+        }
+        if (consumerMaxParallelism > 0) {
+            v3.setMaxParallelism(consumerMaxParallelism);
+        }
+
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        v2.connectNewDataSetAsInput(
+                v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
+        v3.connectNewDataSetAsInput(
+                v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
 
         final JobGraph jobGraph =
                 JobGraphBuilder.newBatchJobGraphBuilder()
-                        .addJobVertices(Arrays.asList(v1, v2))
+                        .addJobVertices(Arrays.asList(v1, v2, v3))
                         .build();
 
         final Configuration configuration = new Configuration();
@@ -287,15 +327,23 @@ public class IntermediateResultPartitionTest {
         source.setInvokableClass(NoOpInvokable.class);
         source.setParallelism(parallelism);
 
-        JobVertex sink = new JobVertex("v2");
-        sink.setInvokableClass(NoOpInvokable.class);
-        sink.setParallelism(parallelism);
+        JobVertex sink1 = new JobVertex("v2");
+        sink1.setInvokableClass(NoOpInvokable.class);
+        sink1.setParallelism(parallelism);
+
+        JobVertex sink2 = new JobVertex("v3");
+        sink2.setInvokableClass(NoOpInvokable.class);
+        sink2.setParallelism(parallelism);
 
-        sink.connectNewDataSetAsInput(source, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        sink1.connectNewDataSetAsInput(
+                source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
+        sink2.connectNewDataSetAsInput(
+                source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
 
         ScheduledExecutorService executorService = new DirectScheduledExecutorService();
 
-        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink);
+        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink1, sink2);
 
         SchedulerBase scheduler =
                 new DefaultSchedulerBuilder(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
index 7c00da22a85..4bf976c76b0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
@@ -19,7 +19,6 @@
 package org.apache.flink.runtime.executiongraph;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.blob.BlobWriter;
 import org.apache.flink.runtime.blob.TestingBlobWriter;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
@@ -27,21 +26,14 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAda
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
 import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
-import org.apache.flink.runtime.jobgraph.JobGraph;
-import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
 import org.apache.flink.runtime.jobgraph.JobVertex;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobmaster.LogicalSlot;
-import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
-import org.apache.flink.runtime.scheduler.DefaultScheduler;
-import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
+import org.apache.flink.runtime.scheduler.SchedulerBase;
+import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
-import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
 
@@ -50,17 +42,16 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
 
 import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishExecutionVertex;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
@@ -119,7 +110,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -132,8 +123,7 @@ class RemoveCachedShuffleDescriptorTest {
 
         // For the all-to-all edge, we transition all downstream tasks to finished
         CompletableFuture.runAsync(
-                        () -> transitionTasksToFinished(executionGraph, v2.getID()),
-                        mainThreadExecutor)
+                        () -> finishJobVertex(executionGraph, v2.getID()), mainThreadExecutor)
                 .join();
         ioExecutor.triggerAll();
 
@@ -164,7 +154,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -206,7 +196,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -222,7 +212,7 @@ class RemoveCachedShuffleDescriptorTest {
                 Objects.requireNonNull(executionGraph.getJobVertex(v2.getID()))
                         .getTaskVertices()[0];
         CompletableFuture.runAsync(
-                        () -> transitionTaskToFinished(executionGraph, ev21), mainThreadExecutor)
+                        () -> finishExecutionVertex(executionGraph, ev21), mainThreadExecutor)
                 .join();
         ioExecutor.triggerAll();
 
@@ -263,7 +253,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -292,42 +282,27 @@ class RemoveCachedShuffleDescriptorTest {
         assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
-    private DefaultScheduler createSchedulerAndDeploy(
+    private SchedulerBase createSchedulerAndDeploy(
             JobID jobId,
             JobVertex v1,
             JobVertex v2,
             DistributionPattern distributionPattern,
             BlobWriter blobWriter)
             throws Exception {
-
-        v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING);
-
-        final List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2));
-        final DefaultScheduler scheduler =
-                createScheduler(jobId, ordered, blobWriter, mainThreadExecutor, ioExecutor);
-        final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
-        final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
-
-        CompletableFuture.runAsync(
-                        () -> {
-                            try {
-                                // Deploy upstream source vertices
-                                deployTasks(executionGraph, v1.getID(), slotBuilder);
-                                // Transition upstream vertices into FINISHED
-                                transitionTasksToFinished(executionGraph, v1.getID());
-                                // Deploy downstream sink vertices
-                                deployTasks(executionGraph, v2.getID(), slotBuilder);
-                            } catch (Exception e) {
-                                throw new RuntimeException("Exceptions shouldn't happen here.", e);
-                            }
-                        },
-                        mainThreadExecutor)
-                .join();
-
-        return scheduler;
+        return SchedulerTestingUtils.createSchedulerAndDeploy(
+                false,
+                jobId,
+                v1,
+                new JobVertex[] {v2},
+                distributionPattern,
+                blobWriter,
+                mainThreadExecutor,
+                ioExecutor,
+                NoOpJobMasterPartitionTracker.INSTANCE,
+                EXECUTOR_RESOURCE.getExecutor());
     }
 
-    private void triggerGlobalFailoverAndComplete(DefaultScheduler scheduler, JobVertex upstream)
+    private void triggerGlobalFailoverAndComplete(SchedulerBase scheduler, JobVertex upstream)
             throws TimeoutException {
 
         final Throwable t = new Exception();
@@ -378,66 +353,6 @@ class RemoveCachedShuffleDescriptorTest {
 
     // ============== Utils ==============
 
-    private static DefaultScheduler createScheduler(
-            final JobID jobId,
-            final List<JobVertex> jobVertices,
-            final BlobWriter blobWriter,
-            final ComponentMainThreadExecutor mainThreadExecutor,
-            final ScheduledExecutorService ioExecutor)
-            throws Exception {
-        final JobGraph jobGraph =
-                JobGraphBuilder.newBatchJobGraphBuilder()
-                        .setJobId(jobId)
-                        .addJobVertices(jobVertices)
-                        .build();
-
-        return new DefaultSchedulerBuilder(
-                        jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
-                .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0))
-                .setBlobWriter(blobWriter)
-                .setIoExecutor(ioExecutor)
-                .build();
-    }
-
-    private static void deployTasks(
-            ExecutionGraph executionGraph,
-            JobVertexID jobVertexID,
-            TestingLogicalSlotBuilder slotBuilder)
-            throws JobException, ExecutionException, InterruptedException {
-
-        for (ExecutionVertex vertex :
-                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
-                        .getTaskVertices()) {
-            LogicalSlot slot = slotBuilder.createTestingLogicalSlot();
-
-            Execution execution = vertex.getCurrentExecutionAttempt();
-            execution.registerProducedPartitions(slot.getTaskManagerLocation()).get();
-            execution.transitionState(ExecutionState.SCHEDULED);
-
-            vertex.tryAssignResource(slot);
-            vertex.deploy();
-        }
-    }
-
-    private static void transitionTasksToFinished(
-            ExecutionGraph executionGraph, JobVertexID jobVertexID) {
-
-        for (ExecutionVertex vertex :
-                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
-                        .getTaskVertices()) {
-            transitionTaskToFinished(executionGraph, vertex);
-        }
-    }
-
-    private static void transitionTaskToFinished(
-            ExecutionGraph executionGraph, ExecutionVertex executionVertex) {
-        executionGraph.updateState(
-                new TaskExecutionStateTransition(
-                        new TaskExecutionState(
-                                executionVertex.getCurrentExecutionAttempt().getAttemptId(),
-                                ExecutionState.FINISHED)));
-    }
-
     private static MaybeOffloaded<ShuffleDescriptor[]> getConsumedCachedShuffleDescriptor(
             ExecutionGraph executionGraph, JobVertex vertex) {
         return getConsumedCachedShuffleDescriptor(executionGraph, vertex, 0);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
index 087a29613e1..130212bcf05 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
@@ -28,8 +28,9 @@ import java.util.Collections;
 import java.util.List;
 
 /** No-op implementation of {@link JobMasterPartitionTracker}. */
-public enum NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker {
-    INSTANCE;
+public class NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker {
+    public static final NoOpJobMasterPartitionTracker INSTANCE =
+            new NoOpJobMasterPartitionTracker();
 
     public static final PartitionTrackerFactory FACTORY = lookup -> INSTANCE;
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
index 66e097b34b1..cf59753999f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
@@ -34,6 +34,7 @@ import org.junit.jupiter.api.Test;
 import java.io.IOException;
 import java.net.URL;
 import java.net.URLClassLoader;
+import java.util.List;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -41,6 +42,45 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 @SuppressWarnings("serial")
 class JobTaskVertexTest {
 
+    @Test
+    void testMultipleConsumersVertices() {
+        JobVertex producer = new JobVertex("producer");
+        JobVertex consumer1 = new JobVertex("consumer1");
+        JobVertex consumer2 = new JobVertex("consumer2");
+
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        consumer1.connectNewDataSetAsInput(
+                producer,
+                DistributionPattern.ALL_TO_ALL,
+                ResultPartitionType.BLOCKING,
+                dataSetId,
+                false);
+        consumer2.connectNewDataSetAsInput(
+                producer,
+                DistributionPattern.ALL_TO_ALL,
+                ResultPartitionType.BLOCKING,
+                dataSetId,
+                false);
+
+        JobVertex consumer3 = new JobVertex("consumer3");
+        consumer3.connectNewDataSetAsInput(
+                producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
+
+        assertThat(producer.getProducedDataSets()).hasSize(2);
+
+        IntermediateDataSet dataSet = producer.getProducedDataSets().get(0);
+        assertThat(dataSet.getId()).isEqualTo(dataSetId);
+
+        List<JobEdge> consumers1 = dataSet.getConsumers();
+        assertThat(consumers1).hasSize(2);
+        assertThat(consumers1.get(0).getTarget().getID()).isEqualTo(consumer1.getID());
+        assertThat(consumers1.get(1).getTarget().getID()).isEqualTo(consumer2.getID());
+
+        List<JobEdge> consumers2 = producer.getProducedDataSets().get(1).getConsumers();
+        assertThat(consumers2).hasSize(1);
+        assertThat(consumers2.get(0).getTarget().getID()).isEqualTo(consumer3.getID());
+    }
+
     @Test
     void testConnectDirectly() {
         JobVertex source = new JobVertex("source");
@@ -59,7 +99,8 @@ class JobTaskVertexTest {
         assertThat(source.getProducedDataSets().get(0))
                 .isEqualTo(target.getInputs().get(0).getSource());
 
-        assertThat(source.getProducedDataSets().get(0).getConsumer().getTarget()).isEqualTo(target);
+        assertThat(source.getProducedDataSets().get(0).getConsumers().get(0).getTarget())
+                .isEqualTo(target);
     }
 
     @Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
index 7f9488fe901..a0ce9132e3f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
@@ -206,7 +206,8 @@ public class JobIntermediateDatasetReuseTest {
                 sender,
                 DistributionPattern.POINTWISE,
                 ResultPartitionType.BLOCKING_PERSISTENT,
-                intermediateDataSetID);
+                intermediateDataSetID,
+                false);
 
         return new JobGraph(null, "First Job", sender, receiver);
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
index eb09a7e139c..b3954d2ea07 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
@@ -20,6 +20,8 @@ package org.apache.flink.runtime.scheduler;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.blob.BlobWriter;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
@@ -28,12 +30,24 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
 import org.apache.flink.runtime.checkpoint.PendingCheckpoint;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
+import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
+import org.apache.flink.runtime.jobmaster.LogicalSlot;
+import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
 import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlotProvider;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
@@ -46,12 +60,18 @@ import org.apache.flink.util.TernaryBoolean;
 import javax.annotation.Nullable;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
 
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
@@ -297,4 +317,121 @@ public class SchedulerTestingUtils {
                 allocationTimeout,
                 new LocalInputPreferredSlotSharingStrategy.Factory());
     }
+
+    public static SchedulerBase createSchedulerAndDeploy(
+            boolean isAdaptive,
+            JobID jobId,
+            JobVertex producer,
+            JobVertex[] consumers,
+            DistributionPattern distributionPattern,
+            BlobWriter blobWriter,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            ScheduledExecutorService ioExecutor,
+            JobMasterPartitionTracker partitionTracker,
+            ScheduledExecutorService scheduledExecutor)
+            throws Exception {
+        final List<JobVertex> vertices = new ArrayList<>(Collections.singletonList(producer));
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        for (JobVertex consumer : consumers) {
+            consumer.connectNewDataSetAsInput(
+                    producer, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
+            vertices.add(consumer);
+        }
+
+        final SchedulerBase scheduler =
+                createScheduler(
+                        isAdaptive,
+                        jobId,
+                        vertices,
+                        blobWriter,
+                        mainThreadExecutor,
+                        ioExecutor,
+                        partitionTracker,
+                        scheduledExecutor);
+        final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
+        final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
+
+        CompletableFuture.runAsync(
+                        () -> {
+                            try {
+                                if (isAdaptive) {
+                                    initializeExecutionJobVertex(producer.getID(), executionGraph);
+                                }
+                                // Deploy upstream source vertices
+                                deployTasks(executionGraph, producer.getID(), slotBuilder);
+                                // Transition upstream vertices into FINISHED
+                                finishJobVertex(executionGraph, producer.getID());
+                                // Deploy downstream sink vertices
+                                for (JobVertex consumer : consumers) {
+                                    if (isAdaptive) {
+                                        initializeExecutionJobVertex(
+                                                consumer.getID(), executionGraph);
+                                    }
+                                    deployTasks(executionGraph, consumer.getID(), slotBuilder);
+                                }
+                            } catch (Exception e) {
+                                throw new RuntimeException("Exceptions shouldn't happen here.", e);
+                            }
+                        },
+                        mainThreadExecutor)
+                .join();
+        return scheduler;
+    }
+
+    private static void initializeExecutionJobVertex(
+            JobVertexID jobVertex, ExecutionGraph executionGraph) {
+        try {
+            executionGraph.initializeJobVertex(
+                    executionGraph.getJobVertex(jobVertex), System.currentTimeMillis());
+            executionGraph.notifyNewlyInitializedJobVertices(
+                    Collections.singletonList(executionGraph.getJobVertex(jobVertex)));
+        } catch (JobException exception) {
+            throw new RuntimeException(exception);
+        }
+    }
+
+    private static DefaultScheduler createScheduler(
+            boolean isAdaptive,
+            JobID jobId,
+            List<JobVertex> jobVertices,
+            BlobWriter blobWriter,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            ScheduledExecutorService ioExecutor,
+            JobMasterPartitionTracker partitionTracker,
+            ScheduledExecutorService scheduledExecutor)
+            throws Exception {
+        final JobGraph jobGraph =
+                JobGraphBuilder.newBatchJobGraphBuilder()
+                        .setJobId(jobId)
+                        .addJobVertices(jobVertices)
+                        .build();
+
+        final DefaultSchedulerBuilder builder =
+                new DefaultSchedulerBuilder(jobGraph, mainThreadExecutor, scheduledExecutor)
+                        .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0))
+                        .setBlobWriter(blobWriter)
+                        .setIoExecutor(ioExecutor)
+                        .setPartitionTracker(partitionTracker);
+        return isAdaptive ? builder.buildAdaptiveBatchJobScheduler() : builder.build();
+    }
+
+    private static void deployTasks(
+            ExecutionGraph executionGraph,
+            JobVertexID jobVertexID,
+            TestingLogicalSlotBuilder slotBuilder)
+            throws JobException, ExecutionException, InterruptedException {
+
+        for (ExecutionVertex vertex :
+                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
+                        .getTaskVertices()) {
+            LogicalSlot slot = slotBuilder.createTestingLogicalSlot();
+
+            Execution execution = vertex.getCurrentExecutionAttempt();
+            execution.registerProducedPartitions(slot.getTaskManagerLocation()).get();
+            execution.transitionState(ExecutionState.SCHEDULED);
+
+            vertex.tryAssignResource(slot);
+            vertex.deploy();
+        }
+    }
 }
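
For orientation, a rough usage sketch of the createSchedulerAndDeploy helper added above. The blob writer, executors and partition tracker are assumed test fixtures (for example VoidBlobWriter.getInstance() and NoOpJobMasterPartitionTracker.INSTANCE), not utilities introduced by this commit:

    JobVertex producer = ExecutionGraphTestUtils.createNoOpVertex(2);
    JobVertex consumer1 = ExecutionGraphTestUtils.createNoOpVertex(2);
    JobVertex consumer2 = ExecutionGraphTestUtils.createNoOpVertex(2);

    // both consumers read the same BLOCKING dataset produced by the producer
    SchedulerBase scheduler =
            SchedulerTestingUtils.createSchedulerAndDeploy(
                    false,                      // plain DefaultScheduler, not adaptive batch
                    new JobID(),
                    producer,
                    new JobVertex[] {consumer1, consumer2},
                    DistributionPattern.ALL_TO_ALL,
                    VoidBlobWriter.getInstance(),
                    mainThreadExecutor,         // ComponentMainThreadExecutor of the test
                    ioExecutor,                 // ScheduledExecutorService
                    NoOpJobMasterPartitionTracker.INSTANCE,
                    scheduledExecutor);         // ScheduledExecutorService
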
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
index 4e1d9cb76a4..f22c3b23dd7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
@@ -52,7 +52,6 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
@@ -314,17 +313,23 @@ class DefaultExecutionTopologyTest {
 
             assertPartitionEquals(originalPartition, adaptedPartition);
 
-            ConsumerVertexGroup consumerVertexGroup = originalPartition.getConsumerVertexGroup();
-            Optional<ConsumerVertexGroup> adaptedConsumers =
-                    adaptedPartition.getConsumerVertexGroup();
-            assertThat(adaptedConsumers).isPresent();
-            for (ExecutionVertexID originalId : consumerVertexGroup) {
+            List<ExecutionVertexID> originalConsumerIds = new ArrayList<>();
+            for (ConsumerVertexGroup consumerVertexGroup :
+                    originalPartition.getConsumerVertexGroups()) {
+                for (ExecutionVertexID executionVertexId : consumerVertexGroup) {
+                    originalConsumerIds.add(executionVertexId);
+                }
+            }
+            List<ConsumerVertexGroup> adaptedConsumers = adaptedPartition.getConsumerVertexGroups();
+            assertThat(adaptedConsumers).isNotEmpty();
+            for (ExecutionVertexID originalId : originalConsumerIds) {
                 // it is sufficient to verify that some vertex exists with the correct ID here,
                 // since deep equality is verified later in the main loop
                 // this DOES rely on an implicit assumption that the vertices objects returned by
                 // the topology are
                 // identical to those stored in the partition
-                assertThat(adaptedConsumers.get()).contains(originalId);
+                assertThat(adaptedConsumers.stream().flatMap(IterableUtils::toStream))
+                        .contains(originalId);
             }
         }
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
index fc2df697464..d6fb9ace1a3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
@@ -81,6 +81,7 @@ class DefaultExecutionVertexTest {
         List<ConsumedPartitionGroup> consumedPartitionGroups =
                 Collections.singletonList(
                         ConsumedPartitionGroup.fromSinglePartition(
+                                1,
                                 intermediateResultPartitionId,
                                 schedulingResultPartition.getResultType()));
         Map<IntermediateResultPartitionID, DefaultResultPartition> resultPartitionById =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
index 6d0024626c1..9f8c58400e3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
@@ -28,7 +28,10 @@ import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Supplier;
 
@@ -47,8 +50,8 @@ class DefaultResultPartitionTest {
 
     private DefaultResultPartition resultPartition;
 
-    private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> consumerVertexGroups =
-            new HashMap<>();
+    private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
+            consumerVertexGroups = new HashMap<>();
 
     @BeforeEach
     void setUp() {
@@ -58,7 +61,9 @@ class DefaultResultPartitionTest {
                         intermediateResultId,
                         BLOCKING,
                         resultPartitionState,
-                        () -> consumerVertexGroups.get(resultPartitionId),
+                        () ->
+                                consumerVertexGroups.computeIfAbsent(
+                                        resultPartitionId, ignored -> new ArrayList<>()),
                         () -> {
                             throw new UnsupportedOperationException();
                         });
@@ -75,14 +80,15 @@ class DefaultResultPartitionTest {
     @Test
     void testGetConsumerVertexGroup() {
 
-        assertThat(resultPartition.getConsumerVertexGroup()).isNotPresent();
+        assertThat(resultPartition.getConsumerVertexGroups()).isEmpty();
 
         // test update consumers
         ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
         consumerVertexGroups.put(
-                resultPartition.getId(), ConsumerVertexGroup.fromSingleVertex(executionVertexId));
-        assertThat(resultPartition.getConsumerVertexGroup()).isPresent();
-        assertThat(resultPartition.getConsumerVertexGroup().get()).contains(executionVertexId);
+                resultPartition.getId(),
+                Collections.singletonList(ConsumerVertexGroup.fromSingleVertex(executionVertexId)));
+        assertThat(resultPartition.getConsumerVertexGroups()).isNotEmpty();
+        assertThat(resultPartition.getConsumerVertexGroups().get(0)).contains(executionVertexId);
     }
 
     /** A test {@link ResultPartitionState} supplier. */
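
The test above exercises the accessor change on SchedulingResultPartition: the former Optional-returning getConsumerVertexGroup() becomes getConsumerVertexGroups(), returning a list so that one partition can serve several consumer job vertices. A short, illustrative traversal under the new API (roughly, one group per consuming job vertex):

    for (ConsumerVertexGroup group : resultPartition.getConsumerVertexGroups()) {
        for (ExecutionVertexID consumerId : group) {
            // handle each consuming execution vertex of the shared dataset
        }
    }
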
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index 3c9bed61354..a6e87f9bfc2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -166,7 +166,7 @@ class AdaptiveBatchSchedulerTest {
         sink.connectNewDataSetAsInput(
                 source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
         if (withForwardEdge) {
-            source1.getProducedDataSets().get(0).getConsumer().setForward(true);
+            source1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
         return new JobGraph(new JobID(), "test job", source1, source2, sink);
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
index b45ba3f5ebe..e9b0da1509d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -97,14 +97,14 @@ class ForwardGroupComputeUtilTest {
         v2.connectNewDataSetAsInput(
                 v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
         if (isForward1) {
-            v1.getProducedDataSets().get(0).getConsumer().setForward(true);
+            v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
 
         v3.connectNewDataSetAsInput(
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
 
         if (isForward2) {
-            v2.getProducedDataSets().get(0).getConsumer().setForward(true);
+            v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
 
         Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3);
@@ -135,10 +135,10 @@ class ForwardGroupComputeUtilTest {
 
         v3.connectNewDataSetAsInput(
                 v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
-        v1.getProducedDataSets().get(0).getConsumer().setForward(true);
+        v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         v3.connectNewDataSetAsInput(
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
-        v2.getProducedDataSets().get(0).getConsumer().setForward(true);
+        v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         v4.connectNewDataSetAsInput(
                 v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
 
@@ -174,8 +174,8 @@ class ForwardGroupComputeUtilTest {
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
         v4.connectNewDataSetAsInput(
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
-        v2.getProducedDataSets().get(0).getConsumer().setForward(true);
-        v2.getProducedDataSets().get(1).getConsumer().setForward(true);
+        v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
+        v2.getProducedDataSets().get(1).getConsumers().get(0).setForward(true);
 
         Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
index 4621680d360..655f725874c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
@@ -93,7 +93,9 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
     void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) {
         final ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromSinglePartition(
-                        consumedPartition.getId(), consumedPartition.getResultType());
+                        consumedPartition.getNumConsumers(),
+                        consumedPartition.getId(),
+                        consumedPartition.getResultType());
 
         consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
         if (consumedPartition.getState() == ResultPartitionState.CONSUMABLE) {
@@ -155,7 +157,8 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
                     partitionIds.add(partitionId);
                 }
                 this.consumedPartitionGroups.add(
-                        ConsumedPartitionGroup.fromMultiplePartitions(partitionIds, resultType));
+                        ConsumedPartitionGroup.fromMultiplePartitions(
+                                partitionGroup.getNumConsumers(), partitionIds, resultType));
             }
             return this;
         }
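
The ConsumedPartitionGroup factory methods used above now take the number of consumers as their first argument. A minimal sketch, with numConsumers, partitionId and partitionIds standing in for values a test would already have at hand:

    ConsumedPartitionGroup single =
            ConsumedPartitionGroup.fromSinglePartition(
                    numConsumers, partitionId, ResultPartitionType.BLOCKING);

    ConsumedPartitionGroup multiple =
            ConsumedPartitionGroup.fromMultiplePartitions(
                    numConsumers, partitionIds, ResultPartitionType.BLOCKING);
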
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
index 9454268cb5f..6274eecb285 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
@@ -28,7 +28,6 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -64,6 +63,10 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
         this.consumedPartitionGroups = new ArrayList<>();
     }
 
+    public int getNumConsumers() {
+        return consumerVertexGroup == null ? 1 : consumerVertexGroup.size();
+    }
+
     @Override
     public IntermediateResultPartitionID getId() {
         return intermediateResultPartitionID;
@@ -90,8 +93,8 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
     }
 
     @Override
-    public Optional<ConsumerVertexGroup> getConsumerVertexGroup() {
-        return Optional.of(consumerVertexGroup);
+    public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+        return Collections.singletonList(consumerVertexGroup);
     }
 
     @Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
index 82463a1a2cb..95596369e33 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
@@ -350,6 +350,7 @@ public class TestingSchedulingTopology implements SchedulingTopology {
 
             ConsumedPartitionGroup consumedPartitionGroup =
                     ConsumedPartitionGroup.fromMultiplePartitions(
+                            consumers.size(),
                             resultPartitions.stream()
                                     .map(TestingSchedulingResultPartition::getId)
                                     .collect(Collectors.toList()),
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index be72cacff5c..4cf9ed7c990 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -1079,19 +1079,20 @@ public class StreamingJobGraphGenerator {
                             headVertex,
                             DistributionPattern.POINTWISE,
                             resultPartitionType,
-                            intermediateDataSetID);
+                            intermediateDataSetID,
+                            partitioner.isBroadcast());
         } else {
             jobEdge =
                     downStreamVertex.connectNewDataSetAsInput(
                             headVertex,
                             DistributionPattern.ALL_TO_ALL,
                             resultPartitionType,
-                            intermediateDataSetID);
+                            intermediateDataSetID,
+                            partitioner.isBroadcast());
         }
 
         // set strategy name so that web interface can show it.
         jobEdge.setShipStrategyName(partitioner.toString());
-        jobEdge.setBroadcast(partitioner.isBroadcast());
         jobEdge.setForward(partitioner instanceof ForwardPartitioner);
         jobEdge.setDownstreamSubtaskStateMapper(partitioner.getDownstreamSubtaskStateMapper());
         jobEdge.setUpstreamSubtaskStateMapper(partitioner.getUpstreamSubtaskStateMapper());
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
index ab50814edd2..137e5040383 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
@@ -55,24 +55,26 @@ public class TestJobDataFlowValidator {
 
         for (JobVertex upstream : testJob.jobGraph.getVertices()) {
             for (IntermediateDataSet produced : upstream.getProducedDataSets()) {
-                JobEdge edge = produced.getConsumer();
-                Optional<String> upstreamIDOptional = getTrackedOperatorID(upstream, true, testJob);
-                Optional<String> downstreamIDOptional =
-                        getTrackedOperatorID(edge.getTarget(), false, testJob);
-                if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
-                    final String upstreamID = upstreamIDOptional.get();
-                    final String downstreamID = downstreamIDOptional.get();
-                    if (testJob.sources.contains(upstreamID)) {
-                        // TODO: if we add tests for FLIP-27 sources we might need to adjust
-                        // this condition
-                        LOG.debug(
-                                "Legacy sources do not have the finish() method and thus do not"
-                                        + " emit FinishEvent");
+                for (JobEdge edge : produced.getConsumers()) {
+                    Optional<String> upstreamIDOptional =
+                            getTrackedOperatorID(upstream, true, testJob);
+                    Optional<String> downstreamIDOptional =
+                            getTrackedOperatorID(edge.getTarget(), false, testJob);
+                    if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
+                        final String upstreamID = upstreamIDOptional.get();
+                        final String downstreamID = downstreamIDOptional.get();
+                        if (testJob.sources.contains(upstreamID)) {
+                            // TODO: if we add tests for FLIP-27 sources we might need to adjust
+                            // this condition
+                            LOG.debug(
+                                    "Legacy sources do not have the finish() method and thus do not"
+                                            + " emit FinishEvent");
+                        } else {
+                            checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+                        }
                     } else {
-                        checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+                        LOG.debug("Ignoring edge (untracked operator): {}", edge);
                     }
-                } else {
-                    LOG.debug("Ignoring edge (untracked operator): {}", edge);
                 }
             }
         }
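
As in the validator rewrite above, code that previously fetched a single JobEdge via IntermediateDataSet#getConsumer() now iterates getConsumers(). A compact sketch of the new traversal, stripped of the test-specific tracking logic:

    for (IntermediateDataSet produced : upstream.getProducedDataSets()) {
        for (JobEdge edge : produced.getConsumers()) {
            JobVertex downstream = edge.getTarget();
            // one edge per consumer job vertex sharing this dataset
        }
    }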