Posted to commits@flink.apache.org by yi...@apache.org on 2022/08/08 14:58:40 UTC

[flink] branch master updated (413912d7fd3 -> 72405361610)

This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


    from 413912d7fd3 [hotfix][python] Make the format imports more explicit by adding format type
     new b3be6bbd9c9 [hotfix][tests] Migrate tests relevant to FLINK-28663 to Junit5/AssertJ
     new 72405361610 [FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side

The 2 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../optimizer/plantranslate/JobGraphGenerator.java |   5 +-
 .../TaskDeploymentDescriptorFactory.java           |   9 +-
 .../executiongraph/DefaultExecutionGraph.java      |  15 +-
 .../flink/runtime/executiongraph/EdgeManager.java  |  24 +-
 .../executiongraph/EdgeManagerBuildUtil.java       |  21 +-
 .../flink/runtime/executiongraph/Execution.java    |  61 ++---
 .../runtime/executiongraph/IntermediateResult.java |  70 +++++-
 .../IntermediateResultPartition.java               |  58 +++--
 .../RestartPipelinedRegionFailoverStrategy.java    |  12 +-
 .../SchedulingPipelinedRegionComputeUtil.java      |  32 ++-
 .../io/network/NettyShuffleEnvironment.java        |  36 ++-
 .../runtime/jobgraph/IntermediateDataSet.java      |  40 ++-
 .../org/apache/flink/runtime/jobgraph/JobEdge.java |  48 +---
 .../apache/flink/runtime/jobgraph/JobGraph.java    |   5 +-
 .../apache/flink/runtime/jobgraph/JobVertex.java   |  41 ++--
 .../SsgNetworkMemoryCalculationUtils.java          |  30 ++-
 .../adapter/DefaultExecutionTopology.java          |  10 +-
 .../scheduler/adapter/DefaultResultPartition.java  |  11 +-
 .../scheduler/strategy/ConsumedPartitionGroup.java |  19 +-
 .../strategy/SchedulingResultPartition.java        |   7 +-
 .../strategy/VertexwiseSchedulingStrategy.java     |  10 +-
 .../BlockingResultPartitionReleaseTest.java        | 151 ++++++++++++
 .../DefaultExecutionGraphConstructionTest.java     | 270 ++++++++++++---------
 .../executiongraph/EdgeManagerBuildUtilTest.java   |  30 +--
 .../runtime/executiongraph/EdgeManagerTest.java    | 112 ++++++---
 .../executiongraph/ExecutionGraphTestUtils.java    |  19 ++
 .../executiongraph/ExecutionJobVertexTest.java     | 130 ++++------
 .../IntermediateResultPartitionTest.java           | 186 ++++++++------
 .../RemoveCachedShuffleDescriptorTest.java         | 215 +++++-----------
 .../partition/NoOpJobMasterPartitionTracker.java   |   5 +-
 .../flink/runtime/jobgraph/JobTaskVertexTest.java  | 195 ++++++++-------
 .../jobmaster/JobIntermediateDatasetReuseTest.java |   3 +-
 .../runtime/scheduler/SchedulerTestingUtils.java   | 137 +++++++++++
 .../adapter/DefaultExecutionTopologyTest.java      | 143 +++++------
 .../adapter/DefaultExecutionVertexTest.java        |  26 +-
 .../adapter/DefaultResultPartitionTest.java        |  43 ++--
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  |  40 ++-
 .../forwardgroup/ForwardGroupComputeUtilTest.java  |  49 ++--
 .../strategy/TestingSchedulingExecutionVertex.java |   7 +-
 .../strategy/TestingSchedulingResultPartition.java |   9 +-
 .../strategy/TestingSchedulingTopology.java        |   1 +
 .../api/graph/StreamingJobGraphGenerator.java      |   7 +-
 .../validation/TestJobDataFlowValidator.java       |  34 +--
 43 files changed, 1414 insertions(+), 962 deletions(-)
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java


[flink] 02/02: [FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side

Posted by yi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 72405361610f1576e0f57ea4e957fafbbbaf0910
Author: kevin.cyj <ke...@alibaba-inc.com>
AuthorDate: Mon Jul 25 16:03:28 2022 +0800

    [FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side
    
    This closes #20350.
---
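For illustration only (not part of the applied patch): a minimal sketch of how, after this change, two downstream job vertices can consume the same intermediate dataset by reusing a single IntermediateDataSetID when connecting their inputs. The vertex names and the standalone main() wrapper are hypothetical; the method signatures (connectNewDataSetAsInput with an explicit IntermediateDataSetID and broadcast flag, getOrCreateResultDataSet on the producer side) are the ones introduced by this commit.

    import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
    import org.apache.flink.runtime.jobgraph.DistributionPattern;
    import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
    import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
    import org.apache.flink.runtime.jobgraph.JobVertex;

    public class SharedDataSetSketch {
        public static void main(String[] args) {
            // Hypothetical vertices; only the wiring below is the point of the sketch.
            JobVertex producer = new JobVertex("producer");
            JobVertex consumerA = new JobVertex("consumer-a");
            JobVertex consumerB = new JobVertex("consumer-b");

            // Both consumers pass the same IntermediateDataSetID, so the producer side
            // goes through getOrCreateResultDataSet() and ends up with one shared dataset.
            IntermediateDataSetID sharedId = new IntermediateDataSetID();
            consumerA.connectNewDataSetAsInput(
                    producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, sharedId, false);
            consumerB.connectNewDataSetAsInput(
                    producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, sharedId, false);

            // The producer now has a single produced dataset with two JobEdge consumers.
            IntermediateDataSet shared = producer.getProducedDataSets().get(0);
            System.out.println(shared.getConsumers().size()); // 2
        }
    }

Note that IntermediateDataSet.addConsumer() still checks that all consumer edges of a shared dataset use the same DistributionPattern and broadcast flag, so the two connections above must agree on both.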
 .../optimizer/plantranslate/JobGraphGenerator.java |   5 +-
 .../TaskDeploymentDescriptorFactory.java           |   9 +-
 .../executiongraph/DefaultExecutionGraph.java      |  15 +-
 .../flink/runtime/executiongraph/EdgeManager.java  |  24 +++-
 .../executiongraph/EdgeManagerBuildUtil.java       |  21 +--
 .../flink/runtime/executiongraph/Execution.java    |  61 +++++----
 .../runtime/executiongraph/IntermediateResult.java |  70 ++++++++--
 .../IntermediateResultPartition.java               |  58 +++++---
 .../RestartPipelinedRegionFailoverStrategy.java    |  12 +-
 .../SchedulingPipelinedRegionComputeUtil.java      |  32 ++---
 .../io/network/NettyShuffleEnvironment.java        |  36 +++--
 .../runtime/jobgraph/IntermediateDataSet.java      |  40 ++++--
 .../org/apache/flink/runtime/jobgraph/JobEdge.java |  48 ++-----
 .../apache/flink/runtime/jobgraph/JobGraph.java    |   5 +-
 .../apache/flink/runtime/jobgraph/JobVertex.java   |  41 +++---
 .../SsgNetworkMemoryCalculationUtils.java          |  30 ++--
 .../adapter/DefaultExecutionTopology.java          |  10 +-
 .../scheduler/adapter/DefaultResultPartition.java  |  11 +-
 .../scheduler/strategy/ConsumedPartitionGroup.java |  19 ++-
 .../strategy/SchedulingResultPartition.java        |   7 +-
 .../strategy/VertexwiseSchedulingStrategy.java     |  10 +-
 .../BlockingResultPartitionReleaseTest.java        | 151 +++++++++++++++++++++
 .../DefaultExecutionGraphConstructionTest.java     |  41 ++++++
 .../executiongraph/EdgeManagerBuildUtilTest.java   |   2 +-
 .../runtime/executiongraph/EdgeManagerTest.java    |  88 +++++++++---
 .../executiongraph/ExecutionGraphTestUtils.java    |  19 +++
 .../executiongraph/ExecutionJobVertexTest.java     |   2 +-
 .../IntermediateResultPartitionTest.java           |  62 ++++++++-
 .../RemoveCachedShuffleDescriptorTest.java         | 133 ++++--------------
 .../partition/NoOpJobMasterPartitionTracker.java   |   5 +-
 .../flink/runtime/jobgraph/JobTaskVertexTest.java  |  43 +++++-
 .../jobmaster/JobIntermediateDatasetReuseTest.java |   3 +-
 .../runtime/scheduler/SchedulerTestingUtils.java   | 137 +++++++++++++++++++
 .../adapter/DefaultExecutionTopologyTest.java      |  19 ++-
 .../adapter/DefaultExecutionVertexTest.java        |   1 +
 .../adapter/DefaultResultPartitionTest.java        |  20 ++-
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  |   2 +-
 .../forwardgroup/ForwardGroupComputeUtilTest.java  |  12 +-
 .../strategy/TestingSchedulingExecutionVertex.java |   7 +-
 .../strategy/TestingSchedulingResultPartition.java |   9 +-
 .../strategy/TestingSchedulingTopology.java        |   1 +
 .../api/graph/StreamingJobGraphGenerator.java      |   7 +-
 .../validation/TestJobDataFlowValidator.java       |  34 ++---
 43 files changed, 958 insertions(+), 404 deletions(-)

diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
index 0da624f1e7f..887da47ed0c 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
@@ -1246,7 +1246,7 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
                 predecessorVertex != null,
                 "Bug: Chained task has not been assigned its containing vertex when connecting.");
 
-        predecessorVertex.createAndAddResultDataSet(
+        predecessorVertex.getOrCreateResultDataSet(
                 // use specified intermediateDataSetID
                 new IntermediateDataSetID(
                         ((BlockingShuffleOutputFormat) userCodeObject).getIntermediateDataSetId()),
@@ -1326,7 +1326,7 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
 
         JobEdge edge =
                 targetVertex.connectNewDataSetAsInput(
-                        sourceVertex, distributionPattern, resultType);
+                        sourceVertex, distributionPattern, resultType, isBroadcast);
 
         // -------------- configure the source task's ship strategy strategies in task config
         // --------------
@@ -1403,7 +1403,6 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
                 channel.getTempMode() == TempMode.NONE ? null : channel.getTempMode().toString();
 
         edge.setShipStrategyName(shipStrategy);
-        edge.setBroadcast(isBroadcast);
         edge.setForward(channel.getShipStrategy() == ShipStrategyType.FORWARD);
         edge.setPreProcessingOperationName(localStrategy);
         edge.setOperatorLevelCachingDescription(caching);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index a2c45ed0c47..ba10dd59de8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -133,7 +133,9 @@ public class TaskDeploymentDescriptorFactory {
             IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
             SubpartitionIndexRange consumedSubpartitionRange =
                     computeConsumedSubpartitionRange(
-                            resultPartition, executionId.getSubtaskIndex());
+                            consumedPartitionGroup.getNumConsumers(),
+                            resultPartition,
+                            executionId.getSubtaskIndex());
 
             IntermediateDataSetID resultId = consumedIntermediateResult.getId();
             ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
@@ -164,8 +166,9 @@ public class TaskDeploymentDescriptorFactory {
     }
 
     public static SubpartitionIndexRange computeConsumedSubpartitionRange(
-            IntermediateResultPartition resultPartition, int consumerSubtaskIndex) {
-        int numConsumers = resultPartition.getConsumerVertexGroup().size();
+            int numConsumers,
+            IntermediateResultPartition resultPartition,
+            int consumerSubtaskIndex) {
         int consumerIndex = consumerSubtaskIndex % numConsumers;
         IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
         int numSubpartitions = resultPartition.getNumberOfSubpartitions();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 2f160f14605..ce1f18c5504 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -1399,6 +1399,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
             final List<ConsumedPartitionGroup> releasablePartitionGroups) {
 
         if (releasablePartitionGroups.size() > 0) {
+            final List<ResultPartitionID> releasablePartitionIds = new ArrayList<>();
 
             // Remove the cache of ShuffleDescriptors when ConsumedPartitionGroups are released
             for (ConsumedPartitionGroup releasablePartitionGroup : releasablePartitionGroups) {
@@ -1406,15 +1407,17 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
                         checkNotNull(
                                 intermediateResults.get(
                                         releasablePartitionGroup.getIntermediateDataSetID()));
+                for (IntermediateResultPartitionID partitionId : releasablePartitionGroup) {
+                    IntermediateResultPartition partition =
+                            totalResult.getPartitionById(partitionId);
+                    partition.markPartitionGroupReleasable(releasablePartitionGroup);
+                    if (partition.canBeReleased()) {
+                        releasablePartitionIds.add(createResultPartitionId(partitionId));
+                    }
+                }
                 totalResult.clearCachedInformationForPartitionGroup(releasablePartitionGroup);
             }
 
-            final List<ResultPartitionID> releasablePartitionIds =
-                    releasablePartitionGroups.stream()
-                            .flatMap(IterableUtils::toStream)
-                            .map(this::createResultPartitionId)
-                            .collect(Collectors.toList());
-
             partitionTracker.stopTrackingAndReleasePartitions(releasablePartitionIds);
         }
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
index 8efc25a913b..1a437e16030 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
@@ -30,12 +30,11 @@ import java.util.List;
 import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
-import static org.apache.flink.util.Preconditions.checkState;
 
 /** Class that manages all the connections between tasks. */
 public class EdgeManager {
 
-    private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> partitionConsumers =
+    private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>> partitionConsumers =
             new HashMap<>();
 
     private final Map<ExecutionVertexID, List<ConsumedPartitionGroup>> vertexConsumedPartitions =
@@ -50,9 +49,9 @@ public class EdgeManager {
 
         checkNotNull(consumerVertexGroup);
 
-        checkState(
-                partitionConsumers.putIfAbsent(resultPartitionId, consumerVertexGroup) == null,
-                "Currently one partition can have at most one consumer group");
+        List<ConsumerVertexGroup> groups =
+                getConsumerVertexGroupsForPartitionInternal(resultPartitionId);
+        groups.add(consumerVertexGroup);
     }
 
     public void connectVertexWithConsumedPartitionGroup(
@@ -66,14 +65,20 @@ public class EdgeManager {
         consumedPartitions.add(consumedPartitionGroup);
     }
 
+    private List<ConsumerVertexGroup> getConsumerVertexGroupsForPartitionInternal(
+            IntermediateResultPartitionID resultPartitionId) {
+        return partitionConsumers.computeIfAbsent(resultPartitionId, id -> new ArrayList<>());
+    }
+
     private List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertexInternal(
             ExecutionVertexID executionVertexId) {
         return vertexConsumedPartitions.computeIfAbsent(executionVertexId, id -> new ArrayList<>());
     }
 
-    public ConsumerVertexGroup getConsumerVertexGroupForPartition(
+    public List<ConsumerVertexGroup> getConsumerVertexGroupsForPartition(
             IntermediateResultPartitionID resultPartitionId) {
-        return partitionConsumers.get(resultPartitionId);
+        return Collections.unmodifiableList(
+                getConsumerVertexGroupsForPartitionInternal(resultPartitionId));
     }
 
     public List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertex(
@@ -100,4 +105,9 @@ public class EdgeManager {
         return Collections.unmodifiableList(
                 getConsumedPartitionGroupsByIdInternal(resultPartitionId));
     }
+
+    public int getNumberOfConsumedPartitionGroupsById(
+            IntermediateResultPartitionID resultPartitionId) {
+        return getConsumedPartitionGroupsByIdInternal(resultPartitionId).size();
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
index 3aa6e59de0d..8ac55b7b7a2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
@@ -89,7 +89,7 @@ public class EdgeManagerBuildUtil {
                         .collect(Collectors.toList());
         ConsumedPartitionGroup consumedPartitionGroup =
                 createAndRegisterConsumedPartitionGroupToEdgeManager(
-                        consumedPartitions, intermediateResult);
+                        taskVertices.length, consumedPartitions, intermediateResult);
         for (ExecutionVertex ev : taskVertices) {
             ev.addConsumedPartitionGroup(consumedPartitionGroup);
         }
@@ -122,7 +122,9 @@ public class EdgeManagerBuildUtil {
 
                 ConsumedPartitionGroup consumedPartitionGroup =
                         createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                partition.getPartitionId(), intermediateResult);
+                                consumerVertexGroup.size(),
+                                partition.getPartitionId(),
+                                intermediateResult);
                 executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
             }
         } else if (sourceCount > targetCount) {
@@ -147,20 +149,19 @@ public class EdgeManagerBuildUtil {
 
                 ConsumedPartitionGroup consumedPartitionGroup =
                         createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                consumedPartitions, intermediateResult);
+                                consumerVertexGroup.size(), consumedPartitions, intermediateResult);
                 executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
             }
         } else {
             for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
+                int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
+                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
 
                 IntermediateResultPartition partition =
                         intermediateResult.getPartitions()[partitionNum];
                 ConsumedPartitionGroup consumedPartitionGroup =
                         createAndRegisterConsumedPartitionGroupToEdgeManager(
-                                partition.getPartitionId(), intermediateResult);
-
-                int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
-                int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
+                                end - start, partition.getPartitionId(), intermediateResult);
 
                 List<ExecutionVertexID> consumers = new ArrayList<>(end - start);
 
@@ -179,21 +180,23 @@ public class EdgeManagerBuildUtil {
     }
 
     private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
+            int numConsumers,
             IntermediateResultPartitionID consumedPartitionId,
             IntermediateResult intermediateResult) {
         ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromSinglePartition(
-                        consumedPartitionId, intermediateResult.getResultType());
+                        numConsumers, consumedPartitionId, intermediateResult.getResultType());
         registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
         return consumedPartitionGroup;
     }
 
     private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
+            int numConsumers,
             List<IntermediateResultPartitionID> consumedPartitions,
             IntermediateResult intermediateResult) {
         ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromMultiplePartitions(
-                        consumedPartitions, intermediateResult.getResultType());
+                        numConsumers, consumedPartitions, intermediateResult.getResultType());
         registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
         return consumedPartitionGroup;
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 657effb66f6..f6de9059869 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -65,6 +65,7 @@ import javax.annotation.Nullable;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -497,10 +498,7 @@ public class Execution
     }
 
     private static int getPartitionMaxParallelism(IntermediateResultPartition partition) {
-        return partition
-                .getIntermediateResult()
-                .getConsumerExecutionJobVertex()
-                .getMaxParallelism();
+        return partition.getIntermediateResult().getConsumersMaxParallelism();
     }
 
     /**
@@ -718,31 +716,40 @@ public class Execution
     }
 
     private void updatePartitionConsumers(final IntermediateResultPartition partition) {
-        final Optional<ConsumerVertexGroup> consumerVertexGroup =
-                partition.getConsumerVertexGroupOptional();
-        if (!consumerVertexGroup.isPresent()) {
+        final List<ConsumerVertexGroup> consumerVertexGroups = partition.getConsumerVertexGroups();
+        if (consumerVertexGroups.isEmpty()) {
             return;
         }
-        for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
-            final ExecutionVertex consumerVertex =
-                    vertex.getExecutionGraphAccessor().getExecutionVertexOrThrow(consumerVertexId);
-            final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
-            final ExecutionState consumerState = consumer.getState();
-
-            // ----------------------------------------------------------------
-            // Consumer is recovering or running => send update message now
-            // Consumer is deploying => cache the partition info which would be
-            // sent after switching to running
-            // ----------------------------------------------------------------
-            if (consumerState == DEPLOYING
-                    || consumerState == RUNNING
-                    || consumerState == INITIALIZING) {
-                final PartitionInfo partitionInfo = createPartitionInfo(partition);
-
-                if (consumerState == DEPLOYING) {
-                    consumerVertex.cachePartitionInfo(partitionInfo);
-                } else {
-                    consumer.sendUpdatePartitionInfoRpcCall(Collections.singleton(partitionInfo));
+        final Set<ExecutionVertexID> updatedVertices = new HashSet<>();
+        for (ConsumerVertexGroup consumerVertexGroup : consumerVertexGroups) {
+            for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
+                if (updatedVertices.contains(consumerVertexId)) {
+                    continue;
+                }
+
+                final ExecutionVertex consumerVertex =
+                        vertex.getExecutionGraphAccessor()
+                                .getExecutionVertexOrThrow(consumerVertexId);
+                final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
+                final ExecutionState consumerState = consumer.getState();
+
+                // ----------------------------------------------------------------
+                // Consumer is recovering or running => send update message now
+                // Consumer is deploying => cache the partition info which would be
+                // sent after switching to running
+                // ----------------------------------------------------------------
+                if (consumerState == DEPLOYING
+                        || consumerState == RUNNING
+                        || consumerState == INITIALIZING) {
+                    final PartitionInfo partitionInfo = createPartitionInfo(partition);
+                    updatedVertices.add(consumerVertexId);
+
+                    if (consumerState == DEPLOYING) {
+                        consumerVertex.cachePartitionInfo(partitionInfo);
+                    } else {
+                        consumer.sendUpdatePartitionInfoRpcCall(
+                                Collections.singleton(partitionInfo));
+                    }
                 }
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
index 4b666b5742b..fd8be042f6f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
@@ -32,12 +32,15 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 public class IntermediateResult {
 
@@ -68,6 +71,9 @@ public class IntermediateResult {
     private final Map<ConsumedPartitionGroup, MaybeOffloaded<ShuffleDescriptor[]>>
             shuffleDescriptorCache;
 
+    /** All consumer job vertex ids of this dataset. */
+    private final List<JobVertexID> consumerVertices = new ArrayList<>();
+
     public IntermediateResult(
             IntermediateDataSet intermediateDataSet,
             ExecutionJobVertex producer,
@@ -95,6 +101,10 @@ public class IntermediateResult {
         this.resultType = checkNotNull(resultType);
 
         this.shuffleDescriptorCache = new HashMap<>();
+
+        intermediateDataSet
+                .getConsumers()
+                .forEach(jobEdge -> consumerVertices.add(jobEdge.getTarget().getID()));
     }
 
     public void setPartition(int partitionNumber, IntermediateResultPartition partition) {
@@ -124,6 +134,10 @@ public class IntermediateResult {
         return partitions;
     }
 
+    public List<JobVertexID> getConsumerVertices() {
+        return consumerVertices;
+    }
+
     /**
      * Returns the partition with the given ID.
      *
@@ -162,20 +176,60 @@ public class IntermediateResult {
         return numParallelProducers;
     }
 
-    ExecutionJobVertex getConsumerExecutionJobVertex() {
-        final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
-        final JobVertexID consumerJobVertexId = consumer.getTarget().getID();
-        return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId));
+    /**
+     * Currently, this method is only used to compute the maximum number of consumers. For dynamic
+     * graph, it should be called before adaptively deciding the downstream consumer parallelism.
+     */
+    int getConsumersParallelism() {
+        List<JobEdge> consumers = intermediateDataSet.getConsumers();
+        checkState(!consumers.isEmpty());
+
+        InternalExecutionGraphAccessor graph = getProducer().getGraph();
+        int consumersParallelism =
+                graph.getJobVertex(consumers.get(0).getTarget().getID()).getParallelism();
+        if (consumers.size() == 1) {
+            return consumersParallelism;
+        }
+
+        // sanity check, all consumer vertices must have the same parallelism:
+        // 1. for vertices that are not assigned a parallelism initially (for example, dynamic
+        // graph), the parallelisms will all be -1 (parallelism not decided yet)
+        // 2. for vertices that are initially assigned a parallelism, the parallelisms must be the
+        // same, which is guaranteed at compilation phase
+        for (JobVertexID jobVertexID : consumerVertices) {
+            checkState(
+                    consumersParallelism == graph.getJobVertex(jobVertexID).getParallelism(),
+                    "Consumers must have the same parallelism.");
+        }
+        return consumersParallelism;
+    }
+
+    int getConsumersMaxParallelism() {
+        List<JobEdge> consumers = intermediateDataSet.getConsumers();
+        checkState(!consumers.isEmpty());
+
+        InternalExecutionGraphAccessor graph = getProducer().getGraph();
+        int consumersMaxParallelism =
+                graph.getJobVertex(consumers.get(0).getTarget().getID()).getMaxParallelism();
+        if (consumers.size() == 1) {
+            return consumersMaxParallelism;
+        }
+
+        // sanity check, all consumer vertices must have the same max parallelism
+        for (JobVertexID jobVertexID : consumerVertices) {
+            checkState(
+                    consumersMaxParallelism == graph.getJobVertex(jobVertexID).getMaxParallelism(),
+                    "Consumers must have the same max parallelism.");
+        }
+        return consumersMaxParallelism;
     }
 
     public DistributionPattern getConsumingDistributionPattern() {
-        final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
-        return consumer.getDistributionPattern();
+        return intermediateDataSet.getDistributionPattern();
     }
 
     public boolean isBroadcast() {
-        final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
-        return consumer.isBroadcast();
+        return intermediateDataSet.isBroadcast();
     }
 
     public int getConnectionIndex() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
index 5aef1b07471..9b9c176a3d9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
@@ -21,11 +21,13 @@ package org.apache.flink.runtime.executiongraph;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 
+import java.util.HashSet;
 import java.util.List;
-import java.util.Optional;
+import java.util.Set;
 
 import static org.apache.flink.util.Preconditions.checkState;
 
@@ -47,6 +49,12 @@ public class IntermediateResultPartition {
     /** Whether this partition has produced some data. */
     private boolean hasDataProduced = false;
 
+    /**
+     * Releasable {@link ConsumedPartitionGroup}s for this result partition. This result partition
+     * can be released if all {@link ConsumedPartitionGroup}s are releasable.
+     */
+    private final Set<ConsumedPartitionGroup> releasablePartitionGroups = new HashSet<>();
+
     public IntermediateResultPartition(
             IntermediateResult totalResult,
             ExecutionVertex producer,
@@ -58,6 +66,25 @@ public class IntermediateResultPartition {
         this.edgeManager = edgeManager;
     }
 
+    public void markPartitionGroupReleasable(ConsumedPartitionGroup partitionGroup) {
+        releasablePartitionGroups.add(partitionGroup);
+    }
+
+    public boolean canBeReleased() {
+        if (releasablePartitionGroups.size()
+                != edgeManager.getNumberOfConsumedPartitionGroupsById(partitionId)) {
+            return false;
+        }
+        for (JobVertexID jobVertexId : totalResult.getConsumerVertices()) {
+            // for dynamic graph, if any consumer vertex is still not initialized, this result
+            // partition can not be released
+            if (!producer.getExecutionGraphAccessor().getJobVertex(jobVertexId).isInitialized()) {
+                return false;
+            }
+        }
+        return true;
+    }
+
     public ExecutionVertex getProducer() {
         return producer;
     }
@@ -78,15 +105,8 @@ public class IntermediateResultPartition {
         return totalResult.getResultType();
     }
 
-    public ConsumerVertexGroup getConsumerVertexGroup() {
-        Optional<ConsumerVertexGroup> consumerVertexGroup = getConsumerVertexGroupOptional();
-        checkState(consumerVertexGroup.isPresent());
-        return consumerVertexGroup.get();
-    }
-
-    public Optional<ConsumerVertexGroup> getConsumerVertexGroupOptional() {
-        return Optional.ofNullable(
-                getEdgeManager().getConsumerVertexGroupForPartition(partitionId));
+    public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+        return getEdgeManager().getConsumerVertexGroupsForPartition(partitionId);
     }
 
     public List<ConsumedPartitionGroup> getConsumedPartitionGroups() {
@@ -106,12 +126,13 @@ public class IntermediateResultPartition {
 
     private int computeNumberOfSubpartitions() {
         if (!getProducer().getExecutionGraphAccessor().isDynamic()) {
-            ConsumerVertexGroup consumerVertexGroup = getConsumerVertexGroup();
-            checkState(consumerVertexGroup.size() > 0);
+            List<ConsumerVertexGroup> consumerVertexGroups = getConsumerVertexGroups();
+            checkState(!consumerVertexGroups.isEmpty());
 
             // The produced data is partitioned among a number of subpartitions, one for each
-            // consuming sub task.
-            return consumerVertexGroup.size();
+            // consuming sub task. All vertex groups must have the same number of consumers
+            // for non-dynamic graph.
+            return consumerVertexGroups.get(0).size();
         } else {
             if (totalResult.isBroadcast()) {
                 // for dynamic graph and broadcast result, we only produced one subpartition,
@@ -124,18 +145,16 @@ public class IntermediateResultPartition {
     }
 
     private int computeNumberOfMaxPossiblePartitionConsumers() {
-        final ExecutionJobVertex consumerJobVertex =
-                getIntermediateResult().getConsumerExecutionJobVertex();
         final DistributionPattern distributionPattern =
                 getIntermediateResult().getConsumingDistributionPattern();
 
         // decide the max possible consumer job vertex parallelism
-        int maxConsumerJobVertexParallelism = consumerJobVertex.getParallelism();
+        int maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersParallelism();
         if (maxConsumerJobVertexParallelism <= 0) {
+            maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersMaxParallelism();
             checkState(
-                    consumerJobVertex.getMaxParallelism() > 0,
+                    maxConsumerJobVertexParallelism > 0,
                     "Neither the parallelism nor the max parallelism of a job vertex is set");
-            maxConsumerJobVertexParallelism = consumerJobVertex.getMaxParallelism();
         }
 
         // compute number of subpartitions according to the distribution pattern
@@ -163,6 +182,7 @@ public class IntermediateResultPartition {
                 consumedPartitionGroup.partitionUnfinished();
             }
         }
+        releasablePartitionGroups.clear();
         hasDataProduced = false;
         for (ConsumedPartitionGroup consumedPartitionGroup : getConsumedPartitionGroups()) {
             totalResult.clearCachedInformationForPartitionGroup(consumedPartitionGroup);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
index 6e7181ffd03..39d7fe72547 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
@@ -230,12 +230,12 @@ public class RestartPipelinedRegionFailoverStrategy implements FailoverStrategy
 
         for (SchedulingExecutionVertex vertex : regionToRestart.getVertices()) {
             for (SchedulingResultPartition producedPartition : vertex.getProducedResults()) {
-                final Optional<ConsumerVertexGroup> consumerVertexGroup =
-                        producedPartition.getConsumerVertexGroup();
-                if (consumerVertexGroup.isPresent()
-                        && !visitedConsumerVertexGroups.contains(consumerVertexGroup.get())) {
-                    visitedConsumerVertexGroups.add(consumerVertexGroup.get());
-                    consumerVertexGroupsToVisit.add(consumerVertexGroup.get());
+                for (ConsumerVertexGroup consumerVertexGroup :
+                        producedPartition.getConsumerVertexGroups()) {
+                    if (!visitedConsumerVertexGroups.contains(consumerVertexGroup)) {
+                        visitedConsumerVertexGroups.add(consumerVertexGroup);
+                        consumerVertexGroupsToVisit.add(consumerVertexGroup);
+                    }
                 }
             }
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
index de353be7ba8..6428284cd49 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
@@ -32,7 +32,6 @@ import java.util.HashSet;
 import java.util.IdentityHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.function.Function;
 
@@ -116,23 +115,20 @@ public final class SchedulingPipelinedRegionComputeUtil {
                     if (producedResult.getResultType().mustBePipelinedConsumed()) {
                         continue;
                     }
-                    final Optional<ConsumerVertexGroup> consumerVertexGroup =
-                            producedResult.getConsumerVertexGroup();
-                    if (!consumerVertexGroup.isPresent()) {
-                        continue;
-                    }
-
-                    for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
-                        SchedulingExecutionVertex consumerVertex =
-                                executionVertexRetriever.apply(consumerVertexId);
-                        // Skip the ConsumerVertexGroup if its vertices are outside current
-                        // regions and cannot be merged
-                        if (!vertexToRegion.containsKey(consumerVertex)) {
-                            break;
-                        }
-                        if (!currentRegion.contains(consumerVertex)) {
-                            currentRegionOutEdges.add(
-                                    regionIndices.get(vertexToRegion.get(consumerVertex)));
+                    for (ConsumerVertexGroup consumerVertexGroup :
+                            producedResult.getConsumerVertexGroups()) {
+                        for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
+                            SchedulingExecutionVertex consumerVertex =
+                                    executionVertexRetriever.apply(consumerVertexId);
+                            // Skip the ConsumerVertexGroup if its vertices are outside current
+                            // regions and cannot be merged
+                            if (!vertexToRegion.containsKey(consumerVertex)) {
+                                break;
+                            }
+                            if (!currentRegion.contains(consumerVertex)) {
+                                currentRegionOutEdges.add(
+                                        regionIndices.get(vertexToRegion.get(consumerVertex)));
+                            }
                         }
                     }
                 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
index aa45dd4b56f..2d55cf4c639 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
@@ -56,6 +56,7 @@ import java.util.Collection;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
@@ -93,7 +94,7 @@ public class NettyShuffleEnvironment
 
     private final FileChannelManager fileChannelManager;
 
-    private final Map<InputGateID, SingleInputGate> inputGatesById;
+    private final Map<InputGateID, Set<SingleInputGate>> inputGatesById;
 
     private final ResultPartitionFactory resultPartitionFactory;
 
@@ -169,7 +170,7 @@ public class NettyShuffleEnvironment
     }
 
     @VisibleForTesting
-    public Optional<InputGate> getInputGate(InputGateID id) {
+    public Optional<Collection<SingleInputGate>> getInputGate(InputGateID id) {
         return Optional.ofNullable(inputGatesById.get(id));
     }
 
@@ -260,8 +261,24 @@ public class NettyShuffleEnvironment
                 InputGateID id =
                         new InputGateID(
                                 igdd.getConsumedResultId(), ownerContext.getExecutionAttemptID());
-                inputGatesById.put(id, inputGate);
-                inputGate.getCloseFuture().thenRun(() -> inputGatesById.remove(id));
+                Set<SingleInputGate> inputGateSet =
+                        inputGatesById.computeIfAbsent(
+                                id, ignored -> ConcurrentHashMap.newKeySet());
+                inputGateSet.add(inputGate);
+                inputGatesById.put(id, inputGateSet);
+                inputGate
+                        .getCloseFuture()
+                        .thenRun(
+                                () ->
+                                        inputGatesById.computeIfPresent(
+                                                id,
+                                                (key, value) -> {
+                                                    value.remove(inputGate);
+                                                    if (value.isEmpty()) {
+                                                        return null;
+                                                    }
+                                                    return value;
+                                                }));
                 inputGates[gateIndex] = inputGate;
             }
 
@@ -297,17 +314,20 @@ public class NettyShuffleEnvironment
         IntermediateDataSetID intermediateResultPartitionID =
                 partitionInfo.getIntermediateDataSetID();
         InputGateID id = new InputGateID(intermediateResultPartitionID, consumerID);
-        SingleInputGate inputGate = inputGatesById.get(id);
-        if (inputGate == null) {
+        Set<SingleInputGate> inputGates = inputGatesById.get(id);
+        if (inputGates == null || inputGates.isEmpty()) {
             return false;
         }
+
         ShuffleDescriptor shuffleDescriptor = partitionInfo.getShuffleDescriptor();
         checkArgument(
                 shuffleDescriptor instanceof NettyShuffleDescriptor,
                 "Tried to update unknown channel with unknown ShuffleDescriptor %s.",
                 shuffleDescriptor.getClass().getName());
-        inputGate.updateInputChannel(
-                taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor);
+        for (SingleInputGate inputGate : inputGates) {
+            inputGate.updateInputChannel(
+                    taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor);
+        }
         return true;
     }
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
index d6b3abc3b52..2aad2613996 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
@@ -20,7 +20,8 @@ package org.apache.flink.runtime.jobgraph;
 
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 
-import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
@@ -39,11 +40,16 @@ public class IntermediateDataSet implements java.io.Serializable {
 
     private final JobVertex producer; // the operation that produced this data set
 
-    @Nullable private JobEdge consumer;
+    // All consumers must have the same partitioner and parallelism
+    private final List<JobEdge> consumers = new ArrayList<>();
 
     // The type of partition to use at runtime
     private final ResultPartitionType resultType;
 
+    private DistributionPattern distributionPattern;
+
+    private boolean isBroadcast;
+
     // --------------------------------------------------------------------------------------------
 
     public IntermediateDataSet(
@@ -63,9 +69,16 @@ public class IntermediateDataSet implements java.io.Serializable {
         return producer;
     }
 
-    @Nullable
-    public JobEdge getConsumer() {
-        return consumer;
+    public List<JobEdge> getConsumers() {
+        return this.consumers;
+    }
+
+    public boolean isBroadcast() {
+        return isBroadcast;
+    }
+
+    public DistributionPattern getDistributionPattern() {
+        return distributionPattern;
     }
 
     public ResultPartitionType getResultType() {
@@ -75,10 +88,19 @@ public class IntermediateDataSet implements java.io.Serializable {
     // --------------------------------------------------------------------------------------------
 
     public void addConsumer(JobEdge edge) {
-        checkState(
-                this.consumer == null,
-                "Currently one IntermediateDataSet can have at most one consumer.");
-        this.consumer = edge;
+        // sanity check
+        checkState(id.equals(edge.getSourceId()), "Incompatible dataset id.");
+
+        if (consumers.isEmpty()) {
+            distributionPattern = edge.getDistributionPattern();
+            isBroadcast = edge.isBroadcast();
+        } else {
+            checkState(
+                    distributionPattern == edge.getDistributionPattern(),
+                    "Incompatible distribution pattern.");
+            checkState(isBroadcast == edge.isBroadcast(), "Incompatible broadcast type.");
+        }
+        consumers.add(edge);
     }
 
     // --------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
index 9772ff4dbb1..4649303c144 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
@@ -44,10 +44,7 @@ public class JobEdge implements java.io.Serializable {
     private SubtaskStateMapper upstreamSubtaskStateMapper = SubtaskStateMapper.ROUND_ROBIN;
 
     /** The data set at the source of the edge, may be null if the edge is not yet connected. */
-    private IntermediateDataSet source;
-
-    /** The id of the source intermediate data set. */
-    private IntermediateDataSetID sourceId;
+    private final IntermediateDataSet source;
 
     /**
      * Optional name for the data shipping strategy (forward, partition hash, rebalance, ...), to be
@@ -55,7 +52,7 @@ public class JobEdge implements java.io.Serializable {
      */
     private String shipStrategyName;
 
-    private boolean isBroadcast;
+    private final boolean isBroadcast;
 
     private boolean isForward;
 
@@ -74,36 +71,20 @@ public class JobEdge implements java.io.Serializable {
      * @param source The data set that is at the source of this edge.
      * @param target The operation that is at the target of this edge.
      * @param distributionPattern The pattern that defines how the connection behaves in parallel.
+     * @param isBroadcast Whether the source broadcasts data to the target.
      */
     public JobEdge(
-            IntermediateDataSet source, JobVertex target, DistributionPattern distributionPattern) {
+            IntermediateDataSet source,
+            JobVertex target,
+            DistributionPattern distributionPattern,
+            boolean isBroadcast) {
         if (source == null || target == null || distributionPattern == null) {
             throw new NullPointerException();
         }
         this.target = target;
         this.distributionPattern = distributionPattern;
         this.source = source;
-        this.sourceId = source.getId();
-    }
-
-    /**
-     * Constructs a new job edge that refers to an intermediate result via the Id, rather than
-     * directly through the intermediate data set structure.
-     *
-     * @param sourceId The id of the data set that is at the source of this edge.
-     * @param target The operation that is at the target of this edge.
-     * @param distributionPattern The pattern that defines how the connection behaves in parallel.
-     */
-    public JobEdge(
-            IntermediateDataSetID sourceId,
-            JobVertex target,
-            DistributionPattern distributionPattern) {
-        if (sourceId == null || target == null || distributionPattern == null) {
-            throw new NullPointerException();
-        }
-        this.target = target;
-        this.distributionPattern = distributionPattern;
-        this.sourceId = sourceId;
+        this.isBroadcast = isBroadcast;
     }
 
     /**
@@ -140,11 +121,7 @@ public class JobEdge implements java.io.Serializable {
      * @return The ID of the consumed data set.
      */
     public IntermediateDataSetID getSourceId() {
-        return sourceId;
-    }
-
-    public boolean isIdReference() {
-        return this.source == null;
+        return source.getId();
     }
 
     // --------------------------------------------------------------------------------------------
@@ -173,11 +150,6 @@ public class JobEdge implements java.io.Serializable {
         return isBroadcast;
     }
 
-    /** Sets whether the edge is broadcast edge. */
-    public void setBroadcast(boolean broadcast) {
-        isBroadcast = broadcast;
-    }
-
     /** Gets whether the edge is forward edge. */
     public boolean isForward() {
         return isForward;
@@ -268,6 +240,6 @@ public class JobEdge implements java.io.Serializable {
 
     @Override
     public String toString() {
-        return String.format("%s --> %s [%s]", sourceId, target, distributionPattern.name());
+        return String.format("%s --> %s [%s]", source.getId(), target, distributionPattern.name());
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
index e8baef4d8e7..bb821adda7b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
@@ -458,10 +458,9 @@ public class JobGraph implements Serializable {
     private void addNodesThatHaveNoNewPredecessors(
             JobVertex start, List<JobVertex> target, Set<JobVertex> remaining) {
 
-        // forward traverse over all produced data sets
+        // forward traverse over all produced data sets and all their consumers
         for (IntermediateDataSet dataSet : start.getProducedDataSets()) {
-            JobEdge edge = dataSet.getConsumer();
-            if (edge != null) {
+            for (JobEdge edge : dataSet.getConsumers()) {
 
                 // a vertex can be added, if it has no predecessors that are still in the
                 // 'remaining' set
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
index 36d1a1e5487..717bd6d4376 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
@@ -36,7 +36,9 @@ import javax.annotation.Nullable;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
@@ -72,8 +74,8 @@ public class JobVertex implements java.io.Serializable {
      */
     private final List<OperatorIDPair> operatorIDs;
 
-    /** List of produced data sets, one per writer. */
-    private final ArrayList<IntermediateDataSet> results = new ArrayList<>();
+    /** Produced data sets, one per writer. */
+    private final Map<IntermediateDataSetID, IntermediateDataSet> results = new LinkedHashMap<>();
 
     /** List of edges with incoming data. One per Reader. */
     private final ArrayList<JobEdge> inputs = new ArrayList<>();
@@ -374,7 +376,7 @@ public class JobVertex implements java.io.Serializable {
     }
 
     public List<IntermediateDataSet> getProducedDataSets() {
-        return this.results;
+        return new ArrayList<>(results.values());
     }
 
     public List<JobEdge> getInputs() {
@@ -481,30 +483,37 @@ public class JobVertex implements java.io.Serializable {
     }
 
     // --------------------------------------------------------------------------------------------
-    public IntermediateDataSet createAndAddResultDataSet(
+    public IntermediateDataSet getOrCreateResultDataSet(
             IntermediateDataSetID id, ResultPartitionType partitionType) {
-
-        IntermediateDataSet result = new IntermediateDataSet(id, partitionType, this);
-        this.results.add(result);
-        return result;
+        return this.results.computeIfAbsent(
+                id, key -> new IntermediateDataSet(id, partitionType, this));
     }
 
     public JobEdge connectNewDataSetAsInput(
             JobVertex input, DistributionPattern distPattern, ResultPartitionType partitionType) {
+        return connectNewDataSetAsInput(input, distPattern, partitionType, false);
+    }
+
+    public JobEdge connectNewDataSetAsInput(
+            JobVertex input,
+            DistributionPattern distPattern,
+            ResultPartitionType partitionType,
+            boolean isBroadcast) {
         return connectNewDataSetAsInput(
-                input, distPattern, partitionType, new IntermediateDataSetID());
+                input, distPattern, partitionType, new IntermediateDataSetID(), isBroadcast);
     }
 
     public JobEdge connectNewDataSetAsInput(
             JobVertex input,
             DistributionPattern distPattern,
             ResultPartitionType partitionType,
-            IntermediateDataSetID intermediateDataSetId) {
+            IntermediateDataSetID intermediateDataSetId,
+            boolean isBroadcast) {
 
         IntermediateDataSet dataSet =
-                input.createAndAddResultDataSet(intermediateDataSetId, partitionType);
+                input.getOrCreateResultDataSet(intermediateDataSetId, partitionType);
 
-        JobEdge edge = new JobEdge(dataSet, this, distPattern);
+        JobEdge edge = new JobEdge(dataSet, this, distPattern, isBroadcast);
         this.inputs.add(edge);
         dataSet.addConsumer(edge);
         return edge;
@@ -525,13 +534,7 @@ public class JobVertex implements java.io.Serializable {
     }
 
     public boolean hasNoConnectedInputs() {
-        for (JobEdge edge : inputs) {
-            if (!edge.isIdReference()) {
-                return false;
-            }
-        }
-
-        return true;
+        return inputs.isEmpty();
     }
 
     public void markContainsSources() {
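
With getOrCreateResultDataSet now idempotent per IntermediateDataSetID, two downstream job
vertices can attach to the same produced data set. A minimal sketch of that wiring, mirroring the
new JobTaskVertexTest further down (vertex names and the shared id are illustrative; the usual
org.apache.flink.runtime.jobgraph and ResultPartitionType imports are assumed):

    JobVertex producer = new JobVertex("producer");
    JobVertex consumerA = new JobVertex("consumerA");
    JobVertex consumerB = new JobVertex("consumerB");

    // Both consumers pass the same IntermediateDataSetID, so the producer creates the data set
    // only once and simply registers a second JobEdge on it for the second consumer.
    IntermediateDataSetID sharedDataSetId = new IntermediateDataSetID();
    consumerA.connectNewDataSetAsInput(
            producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING,
            sharedDataSetId, false);
    consumerB.connectNewDataSetAsInput(
            producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING,
            sharedDataSetId, false);

    // One produced data set with two consuming edges.
    assert producer.getProducedDataSets().size() == 1;
    assert producer.getProducedDataSets().get(0).getConsumers().size() == 2;
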
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index 0fac885dd3f..1506152fa38 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -45,7 +45,6 @@ import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -143,15 +142,22 @@ public class SsgNetworkMemoryCalculationUtils {
         Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
         List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets();
 
-        for (int i = 0; i < producedDataSets.size(); i++) {
-            IntermediateDataSet producedDataSet = producedDataSets.get(i);
-            JobEdge outputEdge = checkNotNull(producedDataSet.getConsumer());
-            ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
-            int maxNum =
-                    EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
-                            ejv.getParallelism(),
-                            consumerJobVertex.getParallelism(),
-                            outputEdge.getDistributionPattern());
+        checkState(!ejv.getGraph().isDynamic(), "Only non-dynamic graphs are supported.");
+        for (IntermediateDataSet producedDataSet : producedDataSets) {
+            int maxNum = 0;
+            List<JobEdge> outputEdges = producedDataSet.getConsumers();
+
+            if (!outputEdges.isEmpty()) {
+                // for a non-dynamic graph, all consumer vertices of a data set must have the
+                // same parallelism and distribution pattern
+                JobEdge outputEdge = outputEdges.get(0);
+                ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
+                maxNum =
+                        EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
+                                ejv.getParallelism(),
+                                consumerJobVertex.getParallelism(),
+                                outputEdge.getDistributionPattern());
+            }
             ret.put(producedDataSet.getId(), maxNum);
         }
 
@@ -177,7 +183,9 @@ public class SsgNetworkMemoryCalculationUtils {
                         ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst()));
                 SubpartitionIndexRange subpartitionIndexRange =
                         TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
-                                resultPartition, vertex.getParallelSubtaskIndex());
+                                partitionGroup.getNumConsumers(),
+                                resultPartition,
+                                vertex.getParallelSubtaskIndex());
 
                 ret.merge(
                         partitionGroup.getIntermediateDataSetID(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
index 486ae28b264..f2fd573e613 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
@@ -262,7 +262,7 @@ public class DefaultExecutionTopology implements SchedulingTopology {
             List<DefaultResultPartition> producedPartitions =
                     generateProducedSchedulingResultPartition(
                             vertex.getProducedPartitions(),
-                            edgeManager::getConsumerVertexGroupForPartition);
+                            edgeManager::getConsumerVertexGroupsForPartition);
 
             producedPartitions.forEach(
                     partition -> resultPartitionsById.put(partition.getId(), partition));
@@ -285,8 +285,8 @@ public class DefaultExecutionTopology implements SchedulingTopology {
     private static List<DefaultResultPartition> generateProducedSchedulingResultPartition(
             Map<IntermediateResultPartitionID, IntermediateResultPartition>
                     producedIntermediatePartitions,
-            Function<IntermediateResultPartitionID, ConsumerVertexGroup>
-                    partitionConsumerVertexGroupRetriever) {
+            Function<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
+                    partitionConsumerVertexGroupsRetriever) {
 
         List<DefaultResultPartition> producedSchedulingPartitions =
                 new ArrayList<>(producedIntermediatePartitions.size());
@@ -305,8 +305,8 @@ public class DefaultExecutionTopology implements SchedulingTopology {
                                                                 ? ResultPartitionState.CONSUMABLE
                                                                 : ResultPartitionState.CREATED,
                                                 () ->
-                                                        partitionConsumerVertexGroupRetriever.apply(
-                                                                irp.getPartitionId()),
+                                                        partitionConsumerVertexGroupsRetriever
+                                                                .apply(irp.getPartitionId()),
                                                 irp::getConsumedPartitionGroups)));
 
         return producedSchedulingPartitions;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
index b46c54bd58e..2ae54af4304 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
@@ -27,7 +27,6 @@ import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition;
 
 import java.util.List;
-import java.util.Optional;
 import java.util.function.Supplier;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -45,7 +44,7 @@ class DefaultResultPartition implements SchedulingResultPartition {
 
     private DefaultExecutionVertex producer;
 
-    private final Supplier<ConsumerVertexGroup> consumerVertexGroupSupplier;
+    private final Supplier<List<ConsumerVertexGroup>> consumerVertexGroupsSupplier;
 
     private final Supplier<List<ConsumedPartitionGroup>> consumerPartitionGroupSupplier;
 
@@ -54,13 +53,13 @@ class DefaultResultPartition implements SchedulingResultPartition {
             IntermediateDataSetID intermediateDataSetId,
             ResultPartitionType partitionType,
             Supplier<ResultPartitionState> resultPartitionStateSupplier,
-            Supplier<ConsumerVertexGroup> consumerVertexGroupSupplier,
+            Supplier<List<ConsumerVertexGroup>> consumerVertexGroupsSupplier,
             Supplier<List<ConsumedPartitionGroup>> consumerPartitionGroupSupplier) {
         this.resultPartitionId = checkNotNull(partitionId);
         this.intermediateDataSetId = checkNotNull(intermediateDataSetId);
         this.partitionType = checkNotNull(partitionType);
         this.resultPartitionStateSupplier = checkNotNull(resultPartitionStateSupplier);
-        this.consumerVertexGroupSupplier = checkNotNull(consumerVertexGroupSupplier);
+        this.consumerVertexGroupsSupplier = checkNotNull(consumerVertexGroupsSupplier);
         this.consumerPartitionGroupSupplier = checkNotNull(consumerPartitionGroupSupplier);
     }
 
@@ -90,8 +89,8 @@ class DefaultResultPartition implements SchedulingResultPartition {
     }
 
     @Override
-    public Optional<ConsumerVertexGroup> getConsumerVertexGroup() {
-        return Optional.ofNullable(consumerVertexGroupSupplier.get());
+    public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+        return checkNotNull(consumerVertexGroupsSupplier.get());
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
index 6e4672c48e9..2f5be93ca51 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
@@ -42,12 +42,17 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
 
     private final ResultPartitionType resultPartitionType;
 
+    /** Number of consumer tasks in the corresponding {@link ConsumerVertexGroup}. */
+    private final int numConsumers;
+
     private ConsumedPartitionGroup(
+            int numConsumers,
             List<IntermediateResultPartitionID> resultPartitions,
             ResultPartitionType resultPartitionType) {
         checkArgument(
                 resultPartitions.size() > 0,
                 "The size of result partitions in the ConsumedPartitionGroup should be larger than 0.");
+        this.numConsumers = numConsumers;
         this.intermediateDataSetID = resultPartitions.get(0).getIntermediateDataSetID();
         this.resultPartitionType = Preconditions.checkNotNull(resultPartitionType);
 
@@ -63,16 +68,18 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
     }
 
     public static ConsumedPartitionGroup fromMultiplePartitions(
+            int numConsumers,
             List<IntermediateResultPartitionID> resultPartitions,
             ResultPartitionType resultPartitionType) {
-        return new ConsumedPartitionGroup(resultPartitions, resultPartitionType);
+        return new ConsumedPartitionGroup(numConsumers, resultPartitions, resultPartitionType);
     }
 
     public static ConsumedPartitionGroup fromSinglePartition(
+            int numConsumers,
             IntermediateResultPartitionID resultPartition,
             ResultPartitionType resultPartitionType) {
         return new ConsumedPartitionGroup(
-                Collections.singletonList(resultPartition), resultPartitionType);
+                numConsumers, Collections.singletonList(resultPartition), resultPartitionType);
     }
 
     @Override
@@ -88,6 +95,14 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
         return resultPartitions.isEmpty();
     }
 
+    /**
+     * In dynamic graph cases, different ConsumedPartitionGroups can have different numbers of
+     * consumers, even if they contain the same IntermediateResultPartitions.
+     */
+    public int getNumConsumers() {
+        return numConsumers;
+    }
+
     public IntermediateResultPartitionID getFirst() {
         return iterator().next();
     }
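
The factory methods above now carry the number of consumer tasks of the corresponding
ConsumerVertexGroup, so callers such as the subpartition-range computation no longer have to
derive the count themselves. A minimal sketch of the new creation path (the parallelism of 5 and
the partition id are illustrative; the runtime jobgraph and scheduler.strategy imports are
assumed):

    // A group built for an edge whose consumer vertex has parallelism 5: every partition in
    // the group is read by 5 consumer tasks.
    ConsumedPartitionGroup group =
            ConsumedPartitionGroup.fromSinglePartition(
                    5,
                    new IntermediateResultPartitionID(new IntermediateDataSetID(), 0),
                    ResultPartitionType.BLOCKING);
    assert group.getNumConsumers() == 5;
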
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
index b52150c96cd..86cafebac6b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
@@ -24,7 +24,6 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.topology.Result;
 
 import java.util.List;
-import java.util.Optional;
 
 /** Representation of {@link IntermediateResultPartition}. */
 public interface SchedulingResultPartition
@@ -49,11 +48,11 @@ public interface SchedulingResultPartition
     ResultPartitionState getState();
 
     /**
-     * Gets the {@link ConsumerVertexGroup}.
+     * Gets the {@link ConsumerVertexGroup}s.
      *
-     * @return {@link ConsumerVertexGroup} if consumers exists, otherwise {@link Optional#empty()}.
+     * @return list of {@link ConsumerVertexGroup}s
      */
-    Optional<ConsumerVertexGroup> getConsumerVertexGroup();
+    List<ConsumerVertexGroup> getConsumerVertexGroups();
 
     /**
      * Gets the {@link ConsumedPartitionGroup}s this partition belongs to.
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
index 4f7d01e4e8f..4d217a9eb97 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
@@ -24,12 +24,12 @@ import org.apache.flink.runtime.scheduler.SchedulerOperations;
 import org.apache.flink.runtime.scheduler.SchedulingTopologyListener;
 import org.apache.flink.util.IterableUtils;
 
+import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -85,11 +85,9 @@ public class VertexwiseSchedulingStrategy
 
             Set<ExecutionVertexID> consumerVertices =
                     IterableUtils.toStream(executionVertex.getProducedResults())
-                            .map(SchedulingResultPartition::getConsumerVertexGroup)
-                            .filter(Optional::isPresent)
-                            .flatMap(
-                                    consumerVertexGroup ->
-                                            IterableUtils.toStream(consumerVertexGroup.get()))
+                            .map(SchedulingResultPartition::getConsumerVertexGroups)
+                            .flatMap(Collection::stream)
+                            .flatMap(IterableUtils::toStream)
                             .collect(Collectors.toSet());
 
             maybeScheduleVertices(consumerVertices);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java
new file mode 100644
index 00000000000..af1c543c509
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.blob.TestingBlobWriter;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
+import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.scheduler.SchedulerBase;
+import org.apache.flink.testutils.TestingUtils;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createSchedulerAndDeploy;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests that blocking result partitions are properly released. */
+class BlockingResultPartitionReleaseTest {
+
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
+
+    private ScheduledExecutorService scheduledExecutorService;
+    private ComponentMainThreadExecutor mainThreadExecutor;
+    private ManuallyTriggeredScheduledExecutorService ioExecutor;
+
+    @BeforeEach
+    void setup() {
+        scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
+        mainThreadExecutor =
+                ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor(
+                        scheduledExecutorService);
+        ioExecutor = new ManuallyTriggeredScheduledExecutorService();
+    }
+
+    @AfterEach
+    void teardown() {
+        if (scheduledExecutorService != null) {
+            scheduledExecutorService.shutdownNow();
+        }
+    }
+
+    @Test
+    void testMultipleConsumersForAdaptiveBatchScheduler() throws Exception {
+        testResultPartitionConsumedByMultiConsumers(true);
+    }
+
+    @Test
+    void testMultipleConsumersForDefaultScheduler() throws Exception {
+        testResultPartitionConsumedByMultiConsumers(false);
+    }
+
+    private void testResultPartitionConsumedByMultiConsumers(boolean isAdaptive) throws Exception {
+        int parallelism = 2;
+        JobID jobId = new JobID();
+        JobVertex producer = ExecutionGraphTestUtils.createNoOpVertex("producer", parallelism);
+        JobVertex consumer1 = ExecutionGraphTestUtils.createNoOpVertex("consumer1", parallelism);
+        JobVertex consumer2 = ExecutionGraphTestUtils.createNoOpVertex("consumer2", parallelism);
+
+        TestingPartitionTracker partitionTracker = new TestingPartitionTracker();
+        SchedulerBase scheduler =
+                createSchedulerAndDeploy(
+                        isAdaptive,
+                        jobId,
+                        producer,
+                        new JobVertex[] {consumer1, consumer2},
+                        DistributionPattern.ALL_TO_ALL,
+                        new TestingBlobWriter(Integer.MAX_VALUE),
+                        mainThreadExecutor,
+                        ioExecutor,
+                        partitionTracker,
+                        EXECUTOR_RESOURCE.getExecutor());
+        ExecutionGraph executionGraph = scheduler.getExecutionGraph();
+
+        assertThat(partitionTracker.releasedPartitions).isEmpty();
+
+        CompletableFuture.runAsync(
+                        () -> finishJobVertex(executionGraph, consumer1.getID()),
+                        mainThreadExecutor)
+                .join();
+        ioExecutor.triggerAll();
+
+        assertThat(partitionTracker.releasedPartitions).isEmpty();
+
+        CompletableFuture.runAsync(
+                        () -> finishJobVertex(executionGraph, consumer2.getID()),
+                        mainThreadExecutor)
+                .join();
+        ioExecutor.triggerAll();
+
+        assertThat(partitionTracker.releasedPartitions.size()).isEqualTo(parallelism);
+        for (int i = 0; i < parallelism; ++i) {
+            ExecutionJobVertex ejv = checkNotNull(executionGraph.getJobVertex(producer.getID()));
+            assertThat(
+                            partitionTracker.releasedPartitions.stream()
+                                    .map(ResultPartitionID::getPartitionId))
+                    .containsExactlyInAnyOrder(
+                            Arrays.stream(ejv.getProducedDataSets()[0].getPartitions())
+                                    .map(IntermediateResultPartition::getPartitionId)
+                                    .toArray(IntermediateResultPartitionID[]::new));
+        }
+    }
+
+    private static class TestingPartitionTracker extends NoOpJobMasterPartitionTracker {
+
+        private final List<ResultPartitionID> releasedPartitions = new ArrayList<>();
+
+        @Override
+        public void stopTrackingAndReleasePartitions(
+                Collection<ResultPartitionID> resultPartitionIds, boolean releaseOnShuffleMaster) {
+            releasedPartitions.addAll(checkNotNull(resultPartitionIds));
+        }
+    }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
index 9245fa8f966..588a91c6fd7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.io.InputSplitSource;
 import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -48,6 +49,7 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
@@ -261,6 +263,45 @@ class DefaultExecutionGraphConstructionTest {
         assertThat(eg.getAllVertices().get(v5.getID()).getSplitAssigner()).isEqualTo(assigner2);
     }
 
+    @Test
+    void testMultiConsumersForOneIntermediateResult() throws Exception {
+        JobVertex v1 = new JobVertex("vertex1");
+        JobVertex v2 = new JobVertex("vertex2");
+        JobVertex v3 = new JobVertex("vertex3");
+
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        v2.connectNewDataSetAsInput(
+                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, dataSetId, false);
+        v3.connectNewDataSetAsInput(
+                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, dataSetId, false);
+
+        List<JobVertex> vertices = new ArrayList<>(Arrays.asList(v1, v2, v3));
+        ExecutionGraph eg = createDefaultExecutionGraph(vertices);
+        eg.attachJobGraph(vertices);
+
+        ExecutionJobVertex ejv1 = checkNotNull(eg.getJobVertex(v1.getID()));
+        assertThat(ejv1.getProducedDataSets()).hasSize(1);
+        assertThat(ejv1.getProducedDataSets()[0].getId()).isEqualTo(dataSetId);
+
+        ExecutionJobVertex ejv2 = checkNotNull(eg.getJobVertex(v2.getID()));
+        assertThat(ejv2.getInputs()).hasSize(1);
+        assertThat(ejv2.getInputs().get(0).getId()).isEqualTo(dataSetId);
+
+        ExecutionJobVertex ejv3 = checkNotNull(eg.getJobVertex(v3.getID()));
+        assertThat(ejv3.getInputs()).hasSize(1);
+        assertThat(ejv3.getInputs().get(0).getId()).isEqualTo(dataSetId);
+
+        List<ConsumedPartitionGroup> partitionGroups1 =
+                ejv2.getTaskVertices()[0].getAllConsumedPartitionGroups();
+        assertThat(partitionGroups1).hasSize(1);
+        assertThat(partitionGroups1.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId);
+
+        List<ConsumedPartitionGroup> partitionGroups2 =
+                ejv3.getTaskVertices()[0].getAllConsumedPartitionGroups();
+        assertThat(partitionGroups2).hasSize(1);
+        assertThat(partitionGroups2.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId);
+    }
+
     @Test
     void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
         JobVertex v1 = new JobVertex("source");
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index e3b603a8790..28b44595a26 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -85,7 +85,7 @@ class EdgeManagerBuildUtilTest {
 
             IntermediateResultPartition partition =
                     ev.getProducedPartitions().values().iterator().next();
-            ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroup();
+            ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroups().get(0);
             int actual = consumerVertexGroup.size();
             if (actual > actualMaxForUpstream) {
                 actualMaxForUpstream = actual;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
index 81165da6193..ad94749e7e8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
@@ -34,9 +34,15 @@ import org.apache.flink.testutils.executor.TestExecutorExtension;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
+import java.util.Arrays;
+import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
 
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link EdgeManager}. */
@@ -50,23 +56,7 @@ class EdgeManagerTest {
     void testGetConsumedPartitionGroup() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
-
-        v1.setParallelism(2);
-        v2.setParallelism(2);
-
-        v1.setInvokableClass(NoOpInvokable.class);
-        v2.setInvokableClass(NoOpInvokable.class);
-
-        v2.connectNewDataSetAsInput(
-                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
-
-        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(v1, v2);
-        SchedulerBase scheduler =
-                SchedulerTestingUtils.createScheduler(
-                        jobGraph,
-                        ComponentMainThreadExecutorServiceAdapter.forMainThread(),
-                        EXECUTOR_RESOURCE.getExecutor());
-        ExecutionGraph eg = scheduler.getExecutionGraph();
+        ExecutionGraph eg = buildExecutionGraph(v1, v2, 2, 2, ALL_TO_ALL);
 
         ConsumedPartitionGroup groupRetrievedByDownstreamVertex =
                 Objects.requireNonNull(eg.getJobVertex(v2.getID()))
@@ -86,9 +76,7 @@ class EdgeManagerTest {
                 .isEqualTo(groupRetrievedByDownstreamVertex);
 
         ConsumedPartitionGroup groupRetrievedByScheduledResultPartition =
-                scheduler
-                        .getExecutionGraph()
-                        .getSchedulingTopology()
+                eg.getSchedulingTopology()
                         .getResultPartition(consumedPartition.getPartitionId())
                         .getConsumedPartitionGroups()
                         .get(0);
@@ -96,4 +84,64 @@ class EdgeManagerTest {
         assertThat(groupRetrievedByScheduledResultPartition)
                 .isEqualTo(groupRetrievedByDownstreamVertex);
     }
+
+    @Test
+    void testCalculateNumberOfConsumers() throws Exception {
+        testCalculateNumberOfConsumers(5, 2, ALL_TO_ALL, new int[] {2, 2});
+        testCalculateNumberOfConsumers(5, 2, POINTWISE, new int[] {1, 1});
+        testCalculateNumberOfConsumers(2, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
+        testCalculateNumberOfConsumers(2, 5, POINTWISE, new int[] {3, 3, 3, 2, 2});
+        testCalculateNumberOfConsumers(5, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
+        testCalculateNumberOfConsumers(5, 5, POINTWISE, new int[] {1, 1, 1, 1, 1});
+    }
+
+    private void testCalculateNumberOfConsumers(
+            int producerParallelism,
+            int consumerParallelism,
+            DistributionPattern distributionPattern,
+            int[] expectedConsumers)
+            throws Exception {
+        JobVertex producer = new JobVertex("producer");
+        JobVertex consumer = new JobVertex("consumer");
+        ExecutionGraph eg =
+                buildExecutionGraph(
+                        producer,
+                        consumer,
+                        producerParallelism,
+                        consumerParallelism,
+                        distributionPattern);
+        List<ConsumedPartitionGroup> partitionGroups =
+                Arrays.stream(checkNotNull(eg.getJobVertex(consumer.getID())).getTaskVertices())
+                        .flatMap(ev -> ev.getAllConsumedPartitionGroups().stream())
+                        .collect(Collectors.toList());
+        int index = 0;
+        for (ConsumedPartitionGroup partitionGroup : partitionGroups) {
+            assertThat(partitionGroup.getNumConsumers()).isEqualTo(expectedConsumers[index++]);
+        }
+    }
+
+    private ExecutionGraph buildExecutionGraph(
+            JobVertex producer,
+            JobVertex consumer,
+            int producerParallelism,
+            int consumerParallelism,
+            DistributionPattern distributionPattern)
+            throws Exception {
+        producer.setParallelism(producerParallelism);
+        consumer.setParallelism(consumerParallelism);
+
+        producer.setInvokableClass(NoOpInvokable.class);
+        consumer.setInvokableClass(NoOpInvokable.class);
+
+        consumer.connectNewDataSetAsInput(
+                producer, distributionPattern, ResultPartitionType.BLOCKING);
+
+        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(producer, consumer);
+        SchedulerBase scheduler =
+                SchedulerTestingUtils.createScheduler(
+                        jobGraph,
+                        ComponentMainThreadExecutorServiceAdapter.forMainThread(),
+                        EXECUTOR_RESOURCE.getExecutor());
+        return scheduler.getExecutionGraph();
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
index ff2baf807e1..a9e7df9470d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
 
@@ -43,6 +44,7 @@ import javax.annotation.Nullable;
 import java.lang.reflect.Field;
 import java.time.Duration;
 import java.util.List;
+import java.util.Objects;
 import java.util.Random;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
@@ -249,6 +251,23 @@ public class ExecutionGraphTestUtils {
         }
     }
 
+    public static void finishJobVertex(ExecutionGraph executionGraph, JobVertexID jobVertexId) {
+        for (ExecutionVertex vertex :
+                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexId))
+                        .getTaskVertices()) {
+            finishExecutionVertex(executionGraph, vertex);
+        }
+    }
+
+    public static void finishExecutionVertex(
+            ExecutionGraph executionGraph, ExecutionVertex executionVertex) {
+        executionGraph.updateState(
+                new TaskExecutionStateTransition(
+                        new TaskExecutionState(
+                                executionVertex.getCurrentExecutionAttempt().getAttemptId(),
+                                ExecutionState.FINISHED)));
+    }
+
     /**
      * Takes all vertices in the given ExecutionGraph and switches their current execution to
      * FINISHED.
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index d847e9ede68..dda16596751 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -170,7 +170,7 @@ class ExecutionJobVertexTest {
             int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception {
         JobVertex jobVertex = new JobVertex("testVertex");
         jobVertex.setInvokableClass(AbstractInvokable.class);
-        jobVertex.createAndAddResultDataSet(
+        jobVertex.getOrCreateResultDataSet(
                 new IntermediateDataSetID(), ResultPartitionType.BLOCKING);
 
         if (maxParallelism > 0) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 5ff5d9f201c..aa906caf674 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
 import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
@@ -151,6 +152,32 @@ public class IntermediateResultPartitionTest {
         assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
     }
 
+    @Test
+    void testReleasePartitionGroups() throws Exception {
+        IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
+
+        IntermediateResultPartition partition1 = result.getPartitions()[0];
+        IntermediateResultPartition partition2 = result.getPartitions()[1];
+        assertThat(partition1.canBeReleased()).isFalse();
+        assertThat(partition2.canBeReleased()).isFalse();
+
+        List<ConsumedPartitionGroup> consumedPartitionGroup1 =
+                partition1.getConsumedPartitionGroups();
+        List<ConsumedPartitionGroup> consumedPartitionGroup2 =
+                partition2.getConsumedPartitionGroups();
+        assertThat(consumedPartitionGroup1).isEqualTo(consumedPartitionGroup2);
+
+        assertThat(consumedPartitionGroup1).hasSize(2);
+        partition1.markPartitionGroupReleasable(consumedPartitionGroup1.get(0));
+        assertThat(partition1.canBeReleased()).isFalse();
+
+        partition1.markPartitionGroupReleasable(consumedPartitionGroup1.get(1));
+        assertThat(partition1.canBeReleased()).isTrue();
+
+        result.resetForNewExecution();
+        assertThat(partition1.canBeReleased()).isFalse();
+    }
+
     @Test
     void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, false, Arrays.asList(7, 7));
@@ -245,11 +272,24 @@ public class IntermediateResultPartitionTest {
             v2.setMaxParallelism(consumerMaxParallelism);
         }
 
-        v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING);
+        final JobVertex v3 = new JobVertex("v3");
+        v3.setInvokableClass(NoOpInvokable.class);
+        if (consumerParallelism > 0) {
+            v3.setParallelism(consumerParallelism);
+        }
+        if (consumerMaxParallelism > 0) {
+            v3.setMaxParallelism(consumerMaxParallelism);
+        }
+
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        v2.connectNewDataSetAsInput(
+                v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
+        v3.connectNewDataSetAsInput(
+                v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
 
         final JobGraph jobGraph =
                 JobGraphBuilder.newBatchJobGraphBuilder()
-                        .addJobVertices(Arrays.asList(v1, v2))
+                        .addJobVertices(Arrays.asList(v1, v2, v3))
                         .build();
 
         final Configuration configuration = new Configuration();
@@ -287,15 +327,23 @@ public class IntermediateResultPartitionTest {
         source.setInvokableClass(NoOpInvokable.class);
         source.setParallelism(parallelism);
 
-        JobVertex sink = new JobVertex("v2");
-        sink.setInvokableClass(NoOpInvokable.class);
-        sink.setParallelism(parallelism);
+        JobVertex sink1 = new JobVertex("v2");
+        sink1.setInvokableClass(NoOpInvokable.class);
+        sink1.setParallelism(parallelism);
+
+        JobVertex sink2 = new JobVertex("v3");
+        sink2.setInvokableClass(NoOpInvokable.class);
+        sink2.setParallelism(parallelism);
 
-        sink.connectNewDataSetAsInput(source, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        sink1.connectNewDataSetAsInput(
+                source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
+        sink2.connectNewDataSetAsInput(
+                source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
 
         ScheduledExecutorService executorService = new DirectScheduledExecutorService();
 
-        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink);
+        JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink1, sink2);
 
         SchedulerBase scheduler =
                 new DefaultSchedulerBuilder(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
index 7c00da22a85..4bf976c76b0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
@@ -19,7 +19,6 @@
 package org.apache.flink.runtime.executiongraph;
 
 import org.apache.flink.api.common.JobID;
-import org.apache.flink.runtime.JobException;
 import org.apache.flink.runtime.blob.BlobWriter;
 import org.apache.flink.runtime.blob.TestingBlobWriter;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
@@ -27,21 +26,14 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAda
 import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
 import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
-import org.apache.flink.runtime.jobgraph.JobGraph;
-import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
 import org.apache.flink.runtime.jobgraph.JobVertex;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobmaster.LogicalSlot;
-import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
-import org.apache.flink.runtime.scheduler.DefaultScheduler;
-import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
+import org.apache.flink.runtime.scheduler.SchedulerBase;
+import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
-import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
 
@@ -50,17 +42,16 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
 
 import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishExecutionVertex;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /**
@@ -119,7 +110,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -132,8 +123,7 @@ class RemoveCachedShuffleDescriptorTest {
 
         // For the all-to-all edge, we transition all downstream tasks to finished
         CompletableFuture.runAsync(
-                        () -> transitionTasksToFinished(executionGraph, v2.getID()),
-                        mainThreadExecutor)
+                        () -> finishJobVertex(executionGraph, v2.getID()), mainThreadExecutor)
                 .join();
         ioExecutor.triggerAll();
 
@@ -164,7 +154,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -206,7 +196,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -222,7 +212,7 @@ class RemoveCachedShuffleDescriptorTest {
                 Objects.requireNonNull(executionGraph.getJobVertex(v2.getID()))
                         .getTaskVertices()[0];
         CompletableFuture.runAsync(
-                        () -> transitionTaskToFinished(executionGraph, ev21), mainThreadExecutor)
+                        () -> finishExecutionVertex(executionGraph, ev21), mainThreadExecutor)
                 .join();
         ioExecutor.triggerAll();
 
@@ -263,7 +253,7 @@ class RemoveCachedShuffleDescriptorTest {
         final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
         final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
 
-        final DefaultScheduler scheduler =
+        final SchedulerBase scheduler =
                 createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
         final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
 
@@ -292,42 +282,27 @@ class RemoveCachedShuffleDescriptorTest {
         assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
-    private DefaultScheduler createSchedulerAndDeploy(
+    private SchedulerBase createSchedulerAndDeploy(
             JobID jobId,
             JobVertex v1,
             JobVertex v2,
             DistributionPattern distributionPattern,
             BlobWriter blobWriter)
             throws Exception {
-
-        v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING);
-
-        final List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2));
-        final DefaultScheduler scheduler =
-                createScheduler(jobId, ordered, blobWriter, mainThreadExecutor, ioExecutor);
-        final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
-        final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
-
-        CompletableFuture.runAsync(
-                        () -> {
-                            try {
-                                // Deploy upstream source vertices
-                                deployTasks(executionGraph, v1.getID(), slotBuilder);
-                                // Transition upstream vertices into FINISHED
-                                transitionTasksToFinished(executionGraph, v1.getID());
-                                // Deploy downstream sink vertices
-                                deployTasks(executionGraph, v2.getID(), slotBuilder);
-                            } catch (Exception e) {
-                                throw new RuntimeException("Exceptions shouldn't happen here.", e);
-                            }
-                        },
-                        mainThreadExecutor)
-                .join();
-
-        return scheduler;
+        return SchedulerTestingUtils.createSchedulerAndDeploy(
+                false,
+                jobId,
+                v1,
+                new JobVertex[] {v2},
+                distributionPattern,
+                blobWriter,
+                mainThreadExecutor,
+                ioExecutor,
+                NoOpJobMasterPartitionTracker.INSTANCE,
+                EXECUTOR_RESOURCE.getExecutor());
     }
 
-    private void triggerGlobalFailoverAndComplete(DefaultScheduler scheduler, JobVertex upstream)
+    private void triggerGlobalFailoverAndComplete(SchedulerBase scheduler, JobVertex upstream)
             throws TimeoutException {
 
         final Throwable t = new Exception();
@@ -378,66 +353,6 @@ class RemoveCachedShuffleDescriptorTest {
 
     // ============== Utils ==============
 
-    private static DefaultScheduler createScheduler(
-            final JobID jobId,
-            final List<JobVertex> jobVertices,
-            final BlobWriter blobWriter,
-            final ComponentMainThreadExecutor mainThreadExecutor,
-            final ScheduledExecutorService ioExecutor)
-            throws Exception {
-        final JobGraph jobGraph =
-                JobGraphBuilder.newBatchJobGraphBuilder()
-                        .setJobId(jobId)
-                        .addJobVertices(jobVertices)
-                        .build();
-
-        return new DefaultSchedulerBuilder(
-                        jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
-                .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0))
-                .setBlobWriter(blobWriter)
-                .setIoExecutor(ioExecutor)
-                .build();
-    }
-
-    private static void deployTasks(
-            ExecutionGraph executionGraph,
-            JobVertexID jobVertexID,
-            TestingLogicalSlotBuilder slotBuilder)
-            throws JobException, ExecutionException, InterruptedException {
-
-        for (ExecutionVertex vertex :
-                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
-                        .getTaskVertices()) {
-            LogicalSlot slot = slotBuilder.createTestingLogicalSlot();
-
-            Execution execution = vertex.getCurrentExecutionAttempt();
-            execution.registerProducedPartitions(slot.getTaskManagerLocation()).get();
-            execution.transitionState(ExecutionState.SCHEDULED);
-
-            vertex.tryAssignResource(slot);
-            vertex.deploy();
-        }
-    }
-
-    private static void transitionTasksToFinished(
-            ExecutionGraph executionGraph, JobVertexID jobVertexID) {
-
-        for (ExecutionVertex vertex :
-                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
-                        .getTaskVertices()) {
-            transitionTaskToFinished(executionGraph, vertex);
-        }
-    }
-
-    private static void transitionTaskToFinished(
-            ExecutionGraph executionGraph, ExecutionVertex executionVertex) {
-        executionGraph.updateState(
-                new TaskExecutionStateTransition(
-                        new TaskExecutionState(
-                                executionVertex.getCurrentExecutionAttempt().getAttemptId(),
-                                ExecutionState.FINISHED)));
-    }
-
     private static MaybeOffloaded<ShuffleDescriptor[]> getConsumedCachedShuffleDescriptor(
             ExecutionGraph executionGraph, JobVertex vertex) {
         return getConsumedCachedShuffleDescriptor(executionGraph, vertex, 0);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
index 087a29613e1..130212bcf05 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
@@ -28,8 +28,9 @@ import java.util.Collections;
 import java.util.List;
 
 /** No-op implementation of {@link JobMasterPartitionTracker}. */
-public enum NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker {
-    INSTANCE;
+public class NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker {
+    public static final NoOpJobMasterPartitionTracker INSTANCE =
+            new NoOpJobMasterPartitionTracker();
 
     public static final PartitionTrackerFactory FACTORY = lookup -> INSTANCE;
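
Turning the no-op tracker from an enum into a class is what lets tests such as the new
BlockingResultPartitionReleaseTest above plug in a recording subclass. A minimal sketch of that
pattern (class and field names are illustrative):

    class RecordingPartitionTracker extends NoOpJobMasterPartitionTracker {
        private final List<ResultPartitionID> released = new ArrayList<>();

        @Override
        public void stopTrackingAndReleasePartitions(
                Collection<ResultPartitionID> resultPartitionIds, boolean releaseOnShuffleMaster) {
            // record the released partitions instead of ignoring them, so a test can assert on them
            released.addAll(resultPartitionIds);
        }
    }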
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
index 66e097b34b1..cf59753999f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
@@ -34,6 +34,7 @@ import org.junit.jupiter.api.Test;
 import java.io.IOException;
 import java.net.URL;
 import java.net.URLClassLoader;
+import java.util.List;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -41,6 +42,45 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 @SuppressWarnings("serial")
 class JobTaskVertexTest {
 
+    @Test
+    void testMultipleConsumersVertices() {
+        JobVertex producer = new JobVertex("producer");
+        JobVertex consumer1 = new JobVertex("consumer1");
+        JobVertex consumer2 = new JobVertex("consumer2");
+
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        consumer1.connectNewDataSetAsInput(
+                producer,
+                DistributionPattern.ALL_TO_ALL,
+                ResultPartitionType.BLOCKING,
+                dataSetId,
+                false);
+        consumer2.connectNewDataSetAsInput(
+                producer,
+                DistributionPattern.ALL_TO_ALL,
+                ResultPartitionType.BLOCKING,
+                dataSetId,
+                false);
+
+        JobVertex consumer3 = new JobVertex("consumer3");
+        consumer3.connectNewDataSetAsInput(
+                producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
+
+        assertThat(producer.getProducedDataSets()).hasSize(2);
+
+        IntermediateDataSet dataSet = producer.getProducedDataSets().get(0);
+        assertThat(dataSet.getId()).isEqualTo(dataSetId);
+
+        List<JobEdge> consumers1 = dataSet.getConsumers();
+        assertThat(consumers1).hasSize(2);
+        assertThat(consumers1.get(0).getTarget().getID()).isEqualTo(consumer1.getID());
+        assertThat(consumers1.get(1).getTarget().getID()).isEqualTo(consumer2.getID());
+
+        List<JobEdge> consumers2 = producer.getProducedDataSets().get(1).getConsumers();
+        assertThat(consumers2).hasSize(1);
+        assertThat(consumers2.get(0).getTarget().getID()).isEqualTo(consumer3.getID());
+    }
+
     @Test
     void testConnectDirectly() {
         JobVertex source = new JobVertex("source");
@@ -59,7 +99,8 @@ class JobTaskVertexTest {
         assertThat(source.getProducedDataSets().get(0))
                 .isEqualTo(target.getInputs().get(0).getSource());
 
-        assertThat(source.getProducedDataSets().get(0).getConsumer().getTarget()).isEqualTo(target);
+        assertThat(source.getProducedDataSets().get(0).getConsumers().get(0).getTarget())
+                .isEqualTo(target);
     }
 
     @Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
index 7f9488fe901..a0ce9132e3f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
@@ -206,7 +206,8 @@ public class JobIntermediateDatasetReuseTest {
                 sender,
                 DistributionPattern.POINTWISE,
                 ResultPartitionType.BLOCKING_PERSISTENT,
-                intermediateDataSetID);
+                intermediateDataSetID,
+                false);
 
         return new JobGraph(null, "First Job", sender, receiver);
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
index eb09a7e139c..b3954d2ea07 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
@@ -20,6 +20,8 @@ package org.apache.flink.runtime.scheduler;
 
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.time.Time;
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.blob.BlobWriter;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
 import org.apache.flink.runtime.checkpoint.CheckpointException;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
@@ -28,12 +30,24 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
 import org.apache.flink.runtime.checkpoint.PendingCheckpoint;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
 import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
+import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
+import org.apache.flink.runtime.jobmaster.LogicalSlot;
+import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
 import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlotProvider;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
@@ -46,12 +60,18 @@ import org.apache.flink.util.TernaryBoolean;
 import javax.annotation.Nullable;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 import java.util.stream.StreamSupport;
 
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
@@ -297,4 +317,121 @@ public class SchedulerTestingUtils {
                 allocationTimeout,
                 new LocalInputPreferredSlotSharingStrategy.Factory());
     }
+
+    public static SchedulerBase createSchedulerAndDeploy(
+            boolean isAdaptive,
+            JobID jobId,
+            JobVertex producer,
+            JobVertex[] consumers,
+            DistributionPattern distributionPattern,
+            BlobWriter blobWriter,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            ScheduledExecutorService ioExecutor,
+            JobMasterPartitionTracker partitionTracker,
+            ScheduledExecutorService scheduledExecutor)
+            throws Exception {
+        final List<JobVertex> vertices = new ArrayList<>(Collections.singletonList(producer));
+        IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+        for (JobVertex consumer : consumers) {
+            consumer.connectNewDataSetAsInput(
+                    producer, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
+            vertices.add(consumer);
+        }
+
+        final SchedulerBase scheduler =
+                createScheduler(
+                        isAdaptive,
+                        jobId,
+                        vertices,
+                        blobWriter,
+                        mainThreadExecutor,
+                        ioExecutor,
+                        partitionTracker,
+                        scheduledExecutor);
+        final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
+        final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
+
+        CompletableFuture.runAsync(
+                        () -> {
+                            try {
+                                if (isAdaptive) {
+                                    initializeExecutionJobVertex(producer.getID(), executionGraph);
+                                }
+                                // Deploy upstream source vertices
+                                deployTasks(executionGraph, producer.getID(), slotBuilder);
+                                // Transition upstream vertices into FINISHED
+                                finishJobVertex(executionGraph, producer.getID());
+                                // Deploy downstream sink vertices
+                                for (JobVertex consumer : consumers) {
+                                    if (isAdaptive) {
+                                        initializeExecutionJobVertex(
+                                                consumer.getID(), executionGraph);
+                                    }
+                                    deployTasks(executionGraph, consumer.getID(), slotBuilder);
+                                }
+                            } catch (Exception e) {
+                                throw new RuntimeException("Exceptions shouldn't happen here.", e);
+                            }
+                        },
+                        mainThreadExecutor)
+                .join();
+        return scheduler;
+    }
+
+    private static void initializeExecutionJobVertex(
+            JobVertexID jobVertex, ExecutionGraph executionGraph) {
+        try {
+            executionGraph.initializeJobVertex(
+                    executionGraph.getJobVertex(jobVertex), System.currentTimeMillis());
+            executionGraph.notifyNewlyInitializedJobVertices(
+                    Collections.singletonList(executionGraph.getJobVertex(jobVertex)));
+        } catch (JobException exception) {
+            throw new RuntimeException(exception);
+        }
+    }
+
+    private static DefaultScheduler createScheduler(
+            boolean isAdaptive,
+            JobID jobId,
+            List<JobVertex> jobVertices,
+            BlobWriter blobWriter,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            ScheduledExecutorService ioExecutor,
+            JobMasterPartitionTracker partitionTracker,
+            ScheduledExecutorService scheduledExecutor)
+            throws Exception {
+        final JobGraph jobGraph =
+                JobGraphBuilder.newBatchJobGraphBuilder()
+                        .setJobId(jobId)
+                        .addJobVertices(jobVertices)
+                        .build();
+
+        final DefaultSchedulerBuilder builder =
+                new DefaultSchedulerBuilder(jobGraph, mainThreadExecutor, scheduledExecutor)
+                        .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0))
+                        .setBlobWriter(blobWriter)
+                        .setIoExecutor(ioExecutor)
+                        .setPartitionTracker(partitionTracker);
+        return isAdaptive ? builder.buildAdaptiveBatchJobScheduler() : builder.build();
+    }
+
+    private static void deployTasks(
+            ExecutionGraph executionGraph,
+            JobVertexID jobVertexID,
+            TestingLogicalSlotBuilder slotBuilder)
+            throws JobException, ExecutionException, InterruptedException {
+
+        for (ExecutionVertex vertex :
+                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
+                        .getTaskVertices()) {
+            LogicalSlot slot = slotBuilder.createTestingLogicalSlot();
+
+            Execution execution = vertex.getCurrentExecutionAttempt();
+            execution.registerProducedPartitions(slot.getTaskManagerLocation()).get();
+            execution.transitionState(ExecutionState.SCHEDULED);
+
+            vertex.tryAssignResource(slot);
+            vertex.deploy();
+        }
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
index 4e1d9cb76a4..f22c3b23dd7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
@@ -52,7 +52,6 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
@@ -314,17 +313,23 @@ class DefaultExecutionTopologyTest {
 
             assertPartitionEquals(originalPartition, adaptedPartition);
 
-            ConsumerVertexGroup consumerVertexGroup = originalPartition.getConsumerVertexGroup();
-            Optional<ConsumerVertexGroup> adaptedConsumers =
-                    adaptedPartition.getConsumerVertexGroup();
-            assertThat(adaptedConsumers).isPresent();
-            for (ExecutionVertexID originalId : consumerVertexGroup) {
+            List<ExecutionVertexID> originalConsumerIds = new ArrayList<>();
+            for (ConsumerVertexGroup consumerVertexGroup :
+                    originalPartition.getConsumerVertexGroups()) {
+                for (ExecutionVertexID executionVertexId : consumerVertexGroup) {
+                    originalConsumerIds.add(executionVertexId);
+                }
+            }
+            List<ConsumerVertexGroup> adaptedConsumers = adaptedPartition.getConsumerVertexGroups();
+            assertThat(adaptedConsumers).isNotEmpty();
+            for (ExecutionVertexID originalId : originalConsumerIds) {
                 // it is sufficient to verify that some vertex exists with the correct ID here,
                 // since deep equality is verified later in the main loop
                 // this DOES rely on an implicit assumption that the vertices objects returned by
                 // the topology are
                 // identical to those stored in the partition
-                assertThat(adaptedConsumers.get()).contains(originalId);
+                assertThat(adaptedConsumers.stream().flatMap(IterableUtils::toStream))
+                        .contains(originalId);
             }
         }
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
index fc2df697464..d6fb9ace1a3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
@@ -81,6 +81,7 @@ class DefaultExecutionVertexTest {
         List<ConsumedPartitionGroup> consumedPartitionGroups =
                 Collections.singletonList(
                         ConsumedPartitionGroup.fromSinglePartition(
+                                1,
                                 intermediateResultPartitionId,
                                 schedulingResultPartition.getResultType()));
         Map<IntermediateResultPartitionID, DefaultResultPartition> resultPartitionById =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
index 6d0024626c1..9f8c58400e3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
@@ -28,7 +28,10 @@ import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.function.Supplier;
 
@@ -47,8 +50,8 @@ class DefaultResultPartitionTest {
 
     private DefaultResultPartition resultPartition;
 
-    private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> consumerVertexGroups =
-            new HashMap<>();
+    private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
+            consumerVertexGroups = new HashMap<>();
 
     @BeforeEach
     void setUp() {
@@ -58,7 +61,9 @@ class DefaultResultPartitionTest {
                         intermediateResultId,
                         BLOCKING,
                         resultPartitionState,
-                        () -> consumerVertexGroups.get(resultPartitionId),
+                        () ->
+                                consumerVertexGroups.computeIfAbsent(
+                                        resultPartitionId, ignored -> new ArrayList<>()),
                         () -> {
                             throw new UnsupportedOperationException();
                         });
@@ -75,14 +80,15 @@ class DefaultResultPartitionTest {
     @Test
     void testGetConsumerVertexGroup() {
 
-        assertThat(resultPartition.getConsumerVertexGroup()).isNotPresent();
+        assertThat(resultPartition.getConsumerVertexGroups()).isEmpty();
 
         // test update consumers
         ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
         consumerVertexGroups.put(
-                resultPartition.getId(), ConsumerVertexGroup.fromSingleVertex(executionVertexId));
-        assertThat(resultPartition.getConsumerVertexGroup()).isPresent();
-        assertThat(resultPartition.getConsumerVertexGroup().get()).contains(executionVertexId);
+                resultPartition.getId(),
+                Collections.singletonList(ConsumerVertexGroup.fromSingleVertex(executionVertexId)));
+        assertThat(resultPartition.getConsumerVertexGroups()).isNotEmpty();
+        assertThat(resultPartition.getConsumerVertexGroups().get(0)).contains(executionVertexId);
     }
 
     /** A test {@link ResultPartitionState} supplier. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index 3c9bed61354..a6e87f9bfc2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -166,7 +166,7 @@ class AdaptiveBatchSchedulerTest {
         sink.connectNewDataSetAsInput(
                 source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
         if (withForwardEdge) {
-            source1.getProducedDataSets().get(0).getConsumer().setForward(true);
+            source1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
         return new JobGraph(new JobID(), "test job", source1, source2, sink);
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
index b45ba3f5ebe..e9b0da1509d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -97,14 +97,14 @@ class ForwardGroupComputeUtilTest {
         v2.connectNewDataSetAsInput(
                 v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
         if (isForward1) {
-            v1.getProducedDataSets().get(0).getConsumer().setForward(true);
+            v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
 
         v3.connectNewDataSetAsInput(
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
 
         if (isForward2) {
-            v2.getProducedDataSets().get(0).getConsumer().setForward(true);
+            v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         }
 
         Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3);
@@ -135,10 +135,10 @@ class ForwardGroupComputeUtilTest {
 
         v3.connectNewDataSetAsInput(
                 v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
-        v1.getProducedDataSets().get(0).getConsumer().setForward(true);
+        v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         v3.connectNewDataSetAsInput(
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
-        v2.getProducedDataSets().get(0).getConsumer().setForward(true);
+        v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
         v4.connectNewDataSetAsInput(
                 v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
 
@@ -174,8 +174,8 @@ class ForwardGroupComputeUtilTest {
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
         v4.connectNewDataSetAsInput(
                 v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
-        v2.getProducedDataSets().get(0).getConsumer().setForward(true);
-        v2.getProducedDataSets().get(1).getConsumer().setForward(true);
+        v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
+        v2.getProducedDataSets().get(1).getConsumers().get(0).setForward(true);
 
         Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4);
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
index 4621680d360..655f725874c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
@@ -93,7 +93,9 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
     void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) {
         final ConsumedPartitionGroup consumedPartitionGroup =
                 ConsumedPartitionGroup.fromSinglePartition(
-                        consumedPartition.getId(), consumedPartition.getResultType());
+                        consumedPartition.getNumConsumers(),
+                        consumedPartition.getId(),
+                        consumedPartition.getResultType());
 
         consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
         if (consumedPartition.getState() == ResultPartitionState.CONSUMABLE) {
@@ -155,7 +157,8 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
                     partitionIds.add(partitionId);
                 }
                 this.consumedPartitionGroups.add(
-                        ConsumedPartitionGroup.fromMultiplePartitions(partitionIds, resultType));
+                        ConsumedPartitionGroup.fromMultiplePartitions(
+                                partitionGroup.getNumConsumers(), partitionIds, resultType));
             }
             return this;
         }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
index 9454268cb5f..6274eecb285 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
@@ -28,7 +28,6 @@ import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -64,6 +63,10 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
         this.consumedPartitionGroups = new ArrayList<>();
     }
 
+    public int getNumConsumers() {
+        return consumerVertexGroup == null ? 1 : consumerVertexGroup.size();
+    }
+
     @Override
     public IntermediateResultPartitionID getId() {
         return intermediateResultPartitionID;
@@ -90,8 +93,8 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
     }
 
     @Override
-    public Optional<ConsumerVertexGroup> getConsumerVertexGroup() {
-        return Optional.of(consumerVertexGroup);
+    public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+        return Collections.singletonList(consumerVertexGroup);
     }
 
     @Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
index 82463a1a2cb..95596369e33 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
@@ -350,6 +350,7 @@ public class TestingSchedulingTopology implements SchedulingTopology {
 
             ConsumedPartitionGroup consumedPartitionGroup =
                     ConsumedPartitionGroup.fromMultiplePartitions(
+                            consumers.size(),
                             resultPartitions.stream()
                                     .map(TestingSchedulingResultPartition::getId)
                                     .collect(Collectors.toList()),
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index be72cacff5c..4cf9ed7c990 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -1079,19 +1079,20 @@ public class StreamingJobGraphGenerator {
                             headVertex,
                             DistributionPattern.POINTWISE,
                             resultPartitionType,
-                            intermediateDataSetID);
+                            intermediateDataSetID,
+                            partitioner.isBroadcast());
         } else {
             jobEdge =
                     downStreamVertex.connectNewDataSetAsInput(
                             headVertex,
                             DistributionPattern.ALL_TO_ALL,
                             resultPartitionType,
-                            intermediateDataSetID);
+                            intermediateDataSetID,
+                            partitioner.isBroadcast());
         }
 
         // set strategy name so that web interface can show it.
         jobEdge.setShipStrategyName(partitioner.toString());
-        jobEdge.setBroadcast(partitioner.isBroadcast());
         jobEdge.setForward(partitioner instanceof ForwardPartitioner);
         jobEdge.setDownstreamSubtaskStateMapper(partitioner.getDownstreamSubtaskStateMapper());
         jobEdge.setUpstreamSubtaskStateMapper(partitioner.getUpstreamSubtaskStateMapper());
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
index ab50814edd2..137e5040383 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
@@ -55,24 +55,26 @@ public class TestJobDataFlowValidator {
 
         for (JobVertex upstream : testJob.jobGraph.getVertices()) {
             for (IntermediateDataSet produced : upstream.getProducedDataSets()) {
-                JobEdge edge = produced.getConsumer();
-                Optional<String> upstreamIDOptional = getTrackedOperatorID(upstream, true, testJob);
-                Optional<String> downstreamIDOptional =
-                        getTrackedOperatorID(edge.getTarget(), false, testJob);
-                if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
-                    final String upstreamID = upstreamIDOptional.get();
-                    final String downstreamID = downstreamIDOptional.get();
-                    if (testJob.sources.contains(upstreamID)) {
-                        // TODO: if we add tests for FLIP-27 sources we might need to adjust
-                        // this condition
-                        LOG.debug(
-                                "Legacy sources do not have the finish() method and thus do not"
-                                        + " emit FinishEvent");
+                for (JobEdge edge : produced.getConsumers()) {
+                    Optional<String> upstreamIDOptional =
+                            getTrackedOperatorID(upstream, true, testJob);
+                    Optional<String> downstreamIDOptional =
+                            getTrackedOperatorID(edge.getTarget(), false, testJob);
+                    if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
+                        final String upstreamID = upstreamIDOptional.get();
+                        final String downstreamID = downstreamIDOptional.get();
+                        if (testJob.sources.contains(upstreamID)) {
+                            // TODO: if we add tests for FLIP-27 sources we might need to adjust
+                            // this condition
+                            LOG.debug(
+                                    "Legacy sources do not have the finish() method and thus do not"
+                                            + " emit FinishEvent");
+                        } else {
+                            checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+                        }
                     } else {
-                        checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+                        LOG.debug("Ignoring edge (untracked operator): {}", edge);
                     }
-                } else {
-                    LOG.debug("Ignoring edge (untracked operator): {}", edge);
                 }
             }
         }

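The hunks above in JobTaskVertexTest, SchedulerTestingUtils and StreamingJobGraphGenerator all exercise the
connectNewDataSetAsInput overload that takes an explicit IntermediateDataSetID plus a broadcast flag, together
with IntermediateDataSet#getConsumers(), which replaces the former single-consumer getConsumer(). Below is a
minimal standalone sketch of that usage, assuming only the signatures visible in those hunks; the class
SharedDataSetSketch itself is hypothetical and not part of this push.

import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;

import java.util.List;

class SharedDataSetSketch {

    static void buildSharedDataSet() {
        JobVertex producer = new JobVertex("producer");
        JobVertex consumerA = new JobVertex("consumerA");
        JobVertex consumerB = new JobVertex("consumerB");

        // Both consumers pass the same dataset id; the trailing boolean carries the
        // broadcast flag that was previously set afterwards via JobEdge#setBroadcast
        // (see the StreamingJobGraphGenerator hunk above).
        IntermediateDataSetID sharedId = new IntermediateDataSetID();
        consumerA.connectNewDataSetAsInput(
                producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, sharedId, false);
        consumerB.connectNewDataSetAsInput(
                producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING, sharedId, false);

        // The producer now exposes a single dataset with one JobEdge per consumer vertex.
        IntermediateDataSet shared = producer.getProducedDataSets().get(0);
        List<JobEdge> consumers = shared.getConsumers();
        // consumers.size() == 2; each edge targets consumerA or consumerB respectively,
        // which is exactly what JobTaskVertexTest#testMultipleConsumersVertices asserts.
    }
}
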

[flink] 01/02: [hotfix][tests] Migrate tests relevant to FLINK-28663 to Junit5/AssertJ

Posted by yi...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit b3be6bbd9c99fa988129497e13d7ae97de17c264
Author: kevin.cyj <ke...@alibaba-inc.com>
AuthorDate: Wed Jul 27 19:57:01 2022 +0800

    [hotfix][tests] Migrate tests relevant to FLINK-28663 to Junit5/AssertJ
    
    Migrated tests include DefaultExecutionGraphConstructionTest, EdgeManagerBuildUtilTest, EdgeManagerTest, ExecutionJobVertexTest, IntermediateResultPartitionTest, RemoveCachedShuffleDescriptorTest, JobTaskVertexTest, DefaultExecutionTopologyTest, DefaultExecutionVertexTest, DefaultResultPartitionTest, AdaptiveBatchSchedulerTest and ForwardGroupComputeUtilTest.
    
    This closes #20350.
---
 .../DefaultExecutionGraphConstructionTest.java     | 229 ++++++++++-----------
 .../executiongraph/EdgeManagerBuildUtilTest.java   |  28 +--
 .../runtime/executiongraph/EdgeManagerTest.java    |  24 ++-
 .../executiongraph/ExecutionJobVertexTest.java     | 128 +++++-------
 .../IntermediateResultPartitionTest.java           | 124 ++++++-----
 .../RemoveCachedShuffleDescriptorTest.java         |  82 ++++----
 .../flink/runtime/jobgraph/JobTaskVertexTest.java  | 154 ++++++--------
 .../adapter/DefaultExecutionTopologyTest.java      | 128 +++++-------
 .../adapter/DefaultExecutionVertexTest.java        |  25 ++-
 .../adapter/DefaultResultPartitionTest.java        |  29 ++-
 .../adaptivebatch/AdaptiveBatchSchedulerTest.java  |  38 ++--
 .../forwardgroup/ForwardGroupComputeUtilTest.java  |  37 ++--
 12 files changed, 462 insertions(+), 564 deletions(-)

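The migration described in the commit message above is mechanical: test classes and methods drop the public
modifier, org.junit.Test becomes org.junit.jupiter.api.Test, @ClassRule with TestExecutorResource becomes
@RegisterExtension with TestExecutorExtension, Hamcrest/JUnit assertions become AssertJ assertThat chains, and
try/fail/catch blocks become assertThatThrownBy. The following is a minimal sketch of the resulting shape,
assuming only the test utilities imported in the hunks below; ExampleMigratedTest itself is hypothetical and
not one of the migrated classes.

import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import java.util.concurrent.ScheduledExecutorService;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

/** Package-private test class; JUnit 5 no longer requires public visibility. */
class ExampleMigratedTest {

    // Before: @ClassRule public static final TestExecutorResource<...> (JUnit 4).
    // After: @RegisterExtension static final TestExecutorExtension<...> (JUnit 5).
    @RegisterExtension
    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorExtension();

    @Test
    void replacesAssertEqualsWithAssertJ() {
        int parallelism = 3;
        // Before: assertEquals(3, parallelism);
        assertThat(parallelism).isEqualTo(3);
    }

    @Test
    void replacesTryFailCatchWithAssertThatThrownBy() {
        // Before: try { ...; fail("failure is expected"); } catch (IllegalStateException e) { /* ignore */ }
        assertThatThrownBy(
                        () -> {
                            throw new IllegalStateException("not initialized");
                        })
                .isInstanceOf(IllegalStateException.class)
                .hasMessageContaining("not initialized");
    }
}
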
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
index ee4383c8296..9245fa8f966 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
@@ -18,8 +18,6 @@
 
 package org.apache.flink.runtime.executiongraph;
 
-import org.apache.flink.api.common.JobID;
-import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.io.InputSplit;
 import org.apache.flink.core.io.InputSplitAssigner;
 import org.apache.flink.core.io.InputSplitSource;
@@ -33,13 +31,12 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
 import org.apache.flink.shaded.guava30.com.google.common.collect.Sets;
 
-import org.junit.ClassRule;
-import org.junit.Test;
-import org.mockito.Matchers;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -51,24 +48,18 @@ import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.hamcrest.Matchers.empty;
-import static org.hamcrest.Matchers.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * This class contains test concerning the correct conversion from {@link JobGraph} to {@link
  * ExecutionGraph} objects. It also tests that {@link EdgeManagerBuildUtil#connectVertexToResult}
  * builds {@link DistributionPattern#ALL_TO_ALL} connections correctly.
  */
-public class DefaultExecutionGraphConstructionTest {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class DefaultExecutionGraphConstructionTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private ExecutionGraph createDefaultExecutionGraph(List<JobVertex> vertices) throws Exception {
         return TestingDefaultExecutionGraphBuilder.newBuilder()
@@ -83,7 +74,7 @@ public class DefaultExecutionGraphConstructionTest {
     }
 
     @Test
-    public void testExecutionAttemptIdInTwoIdenticalJobsIsNotSame() throws Exception {
+    void testExecutionAttemptIdInTwoIdenticalJobsIsNotSame() throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
         JobVertex v3 = new JobVertex("vertex3");
@@ -104,10 +95,10 @@ public class DefaultExecutionGraphConstructionTest {
         eg2.attachJobGraph(ordered);
 
         assertThat(
-                Sets.intersection(
-                        eg1.getRegisteredExecutions().keySet(),
-                        eg2.getRegisteredExecutions().keySet()),
-                is(empty()));
+                        Sets.intersection(
+                                eg1.getRegisteredExecutions().keySet(),
+                                eg2.getRegisteredExecutions().keySet()))
+                .isEmpty();
     }
 
     /**
@@ -124,7 +115,7 @@ public class DefaultExecutionGraphConstructionTest {
      * </pre>
      */
     @Test
-    public void testCreateSimpleGraphBipartite() throws Exception {
+    void testCreateSimpleGraphBipartite() throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
         JobVertex v3 = new JobVertex("vertex3");
@@ -157,13 +148,7 @@ public class DefaultExecutionGraphConstructionTest {
         List<JobVertex> ordered = new ArrayList<JobVertex>(Arrays.asList(v1, v2, v3, v4, v5));
 
         ExecutionGraph eg = createDefaultExecutionGraph(ordered);
-        try {
-            eg.attachJobGraph(ordered);
-        } catch (JobException e) {
-            e.printStackTrace();
-            fail("Job failed with exception: " + e.getMessage());
-        }
-
+        eg.attachJobGraph(ordered);
         verifyTestGraph(eg, v1, v2, v3, v4, v5);
     }
 
@@ -187,7 +172,7 @@ public class DefaultExecutionGraphConstructionTest {
     }
 
     @Test
-    public void testCannotConnectWrongOrder() throws Exception {
+    void testCannotConnectWrongOrder() throws Exception {
         JobVertex v1 = new JobVertex("vertex1");
         JobVertex v2 = new JobVertex("vertex2");
         JobVertex v3 = new JobVertex("vertex3");
@@ -220,88 +205,64 @@ public class DefaultExecutionGraphConstructionTest {
         List<JobVertex> ordered = new ArrayList<JobVertex>(Arrays.asList(v1, v2, v3, v5, v4));
 
         ExecutionGraph eg = createDefaultExecutionGraph(ordered);
-        try {
-            eg.attachJobGraph(ordered);
-            fail("Attached wrong jobgraph");
-        } catch (JobException e) {
-            // expected
-        }
+        assertThatThrownBy(() -> eg.attachJobGraph(ordered)).isInstanceOf(JobException.class);
     }
 
     @Test
-    public void testSetupInputSplits() {
-        try {
-            final InputSplit[] emptySplits = new InputSplit[0];
-
-            InputSplitAssigner assigner1 = mock(InputSplitAssigner.class);
-            InputSplitAssigner assigner2 = mock(InputSplitAssigner.class);
-
-            @SuppressWarnings("unchecked")
-            InputSplitSource<InputSplit> source1 = mock(InputSplitSource.class);
-            @SuppressWarnings("unchecked")
-            InputSplitSource<InputSplit> source2 = mock(InputSplitSource.class);
-
-            when(source1.createInputSplits(Matchers.anyInt())).thenReturn(emptySplits);
-            when(source2.createInputSplits(Matchers.anyInt())).thenReturn(emptySplits);
-            when(source1.getInputSplitAssigner(emptySplits)).thenReturn(assigner1);
-            when(source2.getInputSplitAssigner(emptySplits)).thenReturn(assigner2);
-
-            final JobID jobId = new JobID();
-            final String jobName = "Test Job Sample Name";
-            final Configuration cfg = new Configuration();
-
-            JobVertex v1 = new JobVertex("vertex1");
-            JobVertex v2 = new JobVertex("vertex2");
-            JobVertex v3 = new JobVertex("vertex3");
-            JobVertex v4 = new JobVertex("vertex4");
-            JobVertex v5 = new JobVertex("vertex5");
-
-            v1.setParallelism(5);
-            v2.setParallelism(7);
-            v3.setParallelism(2);
-            v4.setParallelism(11);
-            v5.setParallelism(4);
-
-            v1.setInvokableClass(AbstractInvokable.class);
-            v2.setInvokableClass(AbstractInvokable.class);
-            v3.setInvokableClass(AbstractInvokable.class);
-            v4.setInvokableClass(AbstractInvokable.class);
-            v5.setInvokableClass(AbstractInvokable.class);
-
-            v2.connectNewDataSetAsInput(
-                    v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v4.connectNewDataSetAsInput(
-                    v2, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v4.connectNewDataSetAsInput(
-                    v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v5.connectNewDataSetAsInput(
-                    v4, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-            v5.connectNewDataSetAsInput(
-                    v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-
-            v3.setInputSplitSource(source1);
-            v5.setInputSplitSource(source2);
-
-            List<JobVertex> ordered = new ArrayList<JobVertex>(Arrays.asList(v1, v2, v3, v4, v5));
-
-            ExecutionGraph eg = createDefaultExecutionGraph(ordered);
-            try {
-                eg.attachJobGraph(ordered);
-            } catch (JobException e) {
-                e.printStackTrace();
-                fail("Job failed with exception: " + e.getMessage());
-            }
-
-            assertEquals(assigner1, eg.getAllVertices().get(v3.getID()).getSplitAssigner());
-            assertEquals(assigner2, eg.getAllVertices().get(v5.getID()).getSplitAssigner());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
+    void testSetupInputSplits() throws Exception {
+        final InputSplit[] emptySplits = new InputSplit[0];
+
+        InputSplitAssigner assigner1 = new TestingInputSplitAssigner();
+        InputSplitAssigner assigner2 = new TestingInputSplitAssigner();
+
+        InputSplitSource<InputSplit> source1 =
+                new TestingInputSplitSource<>(emptySplits, assigner1);
+        InputSplitSource<InputSplit> source2 =
+                new TestingInputSplitSource<>(emptySplits, assigner2);
+
+        JobVertex v1 = new JobVertex("vertex1");
+        JobVertex v2 = new JobVertex("vertex2");
+        JobVertex v3 = new JobVertex("vertex3");
+        JobVertex v4 = new JobVertex("vertex4");
+        JobVertex v5 = new JobVertex("vertex5");
+
+        v1.setParallelism(5);
+        v2.setParallelism(7);
+        v3.setParallelism(2);
+        v4.setParallelism(11);
+        v5.setParallelism(4);
+
+        v1.setInvokableClass(AbstractInvokable.class);
+        v2.setInvokableClass(AbstractInvokable.class);
+        v3.setInvokableClass(AbstractInvokable.class);
+        v4.setInvokableClass(AbstractInvokable.class);
+        v5.setInvokableClass(AbstractInvokable.class);
+
+        v2.connectNewDataSetAsInput(
+                v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v4.connectNewDataSetAsInput(
+                v2, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v4.connectNewDataSetAsInput(
+                v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v5.connectNewDataSetAsInput(
+                v4, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        v5.connectNewDataSetAsInput(
+                v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+
+        v3.setInputSplitSource(source1);
+        v5.setInputSplitSource(source2);
+
+        List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2, v3, v4, v5));
+
+        ExecutionGraph eg = createDefaultExecutionGraph(ordered);
+        eg.attachJobGraph(ordered);
+
+        assertThat(eg.getAllVertices().get(v3.getID()).getSplitAssigner()).isEqualTo(assigner1);
+        assertThat(eg.getAllVertices().get(v5.getID()).getSplitAssigner()).isEqualTo(assigner2);
     }
 
     @Test
-    public void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
+    void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
 
@@ -321,9 +282,8 @@ public class DefaultExecutionGraphConstructionTest {
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
 
-        assertEquals(
-                partition1.getConsumedPartitionGroups().get(0),
-                partition2.getConsumedPartitionGroups().get(0));
+        assertThat(partition2.getConsumedPartitionGroups().get(0))
+                .isEqualTo(partition1.getConsumedPartitionGroups().get(0));
 
         ConsumedPartitionGroup consumedPartitionGroup =
                 partition1.getConsumedPartitionGroups().get(0);
@@ -331,13 +291,13 @@ public class DefaultExecutionGraphConstructionTest {
         for (IntermediateResultPartitionID partitionId : consumedPartitionGroup) {
             partitionIds.add(partitionId);
         }
-        assertThat(
-                partitionIds,
-                containsInAnyOrder(partition1.getPartitionId(), partition2.getPartitionId()));
+        assertThat(partitionIds)
+                .containsExactlyInAnyOrder(
+                        partition1.getPartitionId(), partition2.getPartitionId());
     }
 
     @Test
-    public void testAttachToDynamicGraph() throws Exception {
+    void testAttachToDynamicGraph() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
 
@@ -351,9 +311,42 @@ public class DefaultExecutionGraphConstructionTest {
         ExecutionGraph eg = createDynamicExecutionGraph(ordered);
         eg.attachJobGraph(ordered);
 
-        assertThat(eg.getAllVertices().size(), is(2));
+        assertThat(eg.getAllVertices()).hasSize(2);
         Iterator<ExecutionJobVertex> jobVertices = eg.getVerticesTopologically().iterator();
-        assertThat(jobVertices.next().isInitialized(), is(false));
-        assertThat(jobVertices.next().isInitialized(), is(false));
+        assertThat(jobVertices.next().isInitialized()).isFalse();
+        assertThat(jobVertices.next().isInitialized()).isFalse();
+    }
+
+    private static final class TestingInputSplitAssigner implements InputSplitAssigner {
+
+        @Override
+        public InputSplit getNextInputSplit(String host, int taskId) {
+            return null;
+        }
+
+        @Override
+        public void returnInputSplit(List<InputSplit> splits, int taskId) {}
+    }
+
+    private static final class TestingInputSplitSource<T extends InputSplit>
+            implements InputSplitSource<T> {
+
+        private final T[] inputSplits;
+        private final InputSplitAssigner assigner;
+
+        private TestingInputSplitSource(T[] inputSplits, InputSplitAssigner assigner) {
+            this.inputSplits = inputSplits;
+            this.assigner = assigner;
+        }
+
+        @Override
+        public T[] createInputSplits(int minNumSplits) throws Exception {
+            return inputSplits;
+        }
+
+        @Override
+        public InputSplitAssigner getInputSplitAssigner(T[] inputSplits) {
+            return assigner;
+        }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index 2cd5a8d1ce3..e3b603a8790 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -25,11 +25,11 @@ import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
 import org.apache.commons.lang3.tuple.Pair;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -38,20 +38,20 @@ import java.util.concurrent.ScheduledExecutorService;
 
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * Tests for {@link EdgeManagerBuildUtil} to verify the max number of connecting edges between
  * vertices for pattern of both {@link DistributionPattern#POINTWISE} and {@link
  * DistributionPattern#ALL_TO_ALL}.
  */
-public class EdgeManagerBuildUtilTest {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class EdgeManagerBuildUtilTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testGetMaxNumEdgesToTargetInPointwiseConnection() throws Exception {
+    void testGetMaxNumEdgesToTargetInPointwiseConnection() throws Exception {
         testGetMaxNumEdgesToTarget(17, 17, POINTWISE);
         testGetMaxNumEdgesToTarget(17, 23, POINTWISE);
         testGetMaxNumEdgesToTarget(17, 34, POINTWISE);
@@ -60,7 +60,7 @@ public class EdgeManagerBuildUtilTest {
     }
 
     @Test
-    public void testGetMaxNumEdgesToTargetInAllToAllConnection() throws Exception {
+    void testGetMaxNumEdgesToTargetInAllToAllConnection() throws Exception {
         testGetMaxNumEdgesToTarget(17, 17, ALL_TO_ALL);
         testGetMaxNumEdgesToTarget(17, 23, ALL_TO_ALL);
         testGetMaxNumEdgesToTarget(17, 34, ALL_TO_ALL);
@@ -81,7 +81,7 @@ public class EdgeManagerBuildUtilTest {
                         upstream, downstream, pattern);
         int actualMaxForUpstream = -1;
         for (ExecutionVertex ev : upstreamEJV.getTaskVertices()) {
-            assertEquals(1, ev.getProducedPartitions().size());
+            assertThat(ev.getProducedPartitions()).hasSize(1);
 
             IntermediateResultPartition partition =
                     ev.getProducedPartitions().values().iterator().next();
@@ -91,21 +91,21 @@ public class EdgeManagerBuildUtilTest {
                 actualMaxForUpstream = actual;
             }
         }
-        assertEquals(actualMaxForUpstream, calculatedMaxForUpstream);
+        assertThat(actualMaxForUpstream).isEqualTo(calculatedMaxForUpstream);
 
         int calculatedMaxForDownstream =
                 EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
                         downstream, upstream, pattern);
         int actualMaxForDownstream = -1;
         for (ExecutionVertex ev : downstreamEJV.getTaskVertices()) {
-            assertEquals(1, ev.getNumberOfInputs());
+            assertThat(ev.getNumberOfInputs()).isEqualTo(1);
 
             int actual = ev.getConsumedPartitionGroup(0).size();
             if (actual > actualMaxForDownstream) {
                 actualMaxForDownstream = actual;
             }
         }
-        assertEquals(actualMaxForDownstream, calculatedMaxForDownstream);
+        assertThat(actualMaxForDownstream).isEqualTo(calculatedMaxForDownstream);
     }
 
     private Pair<ExecutionJobVertex, ExecutionJobVertex> setupExecutionGraph(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
index 323a753a511..81165da6193 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
@@ -29,25 +29,25 @@ import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Objects;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link EdgeManager}. */
-public class EdgeManagerTest {
+class EdgeManagerTest {
 
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testGetConsumedPartitionGroup() throws Exception {
+    void testGetConsumedPartitionGroup() throws Exception {
         JobVertex v1 = new JobVertex("source");
         JobVertex v2 = new JobVertex("sink");
 
@@ -82,7 +82,8 @@ public class EdgeManagerTest {
         ConsumedPartitionGroup groupRetrievedByIntermediateResultPartition =
                 consumedPartition.getConsumedPartitionGroups().get(0);
 
-        assertEquals(groupRetrievedByDownstreamVertex, groupRetrievedByIntermediateResultPartition);
+        assertThat(groupRetrievedByIntermediateResultPartition)
+                .isEqualTo(groupRetrievedByDownstreamVertex);
 
         ConsumedPartitionGroup groupRetrievedByScheduledResultPartition =
                 scheduler
@@ -92,6 +93,7 @@ public class EdgeManagerTest {
                         .getConsumedPartitionGroups()
                         .get(0);
 
-        assertEquals(groupRetrievedByDownstreamVertex, groupRetrievedByScheduledResultPartition);
+        assertThat(groupRetrievedByScheduledResultPartition)
+                .isEqualTo(groupRetrievedByDownstreamVertex);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index 8f2a35cece5..d847e9ede68 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -29,164 +29,128 @@ import org.apache.flink.runtime.scheduler.VertexParallelismInformation;
 import org.apache.flink.runtime.scheduler.VertexParallelismStore;
 import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.Assert;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Collections;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static org.apache.flink.core.testutils.CommonTestUtils.assertThrows;
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Test for {@link ExecutionJobVertex} */
-public class ExecutionJobVertexTest {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class ExecutionJobVertexTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testParallelismGreaterThanMaxParallelism() {
+    void testParallelismGreaterThanMaxParallelism() {
         JobVertex jobVertex = new JobVertex("testVertex");
         jobVertex.setInvokableClass(AbstractInvokable.class);
         // parallelism must be smaller than the max parallelism
         jobVertex.setParallelism(172);
         jobVertex.setMaxParallelism(4);
 
-        assertThrows(
-                "higher than the max parallelism",
-                JobException.class,
-                () -> ExecutionGraphTestUtils.getExecutionJobVertex(jobVertex));
+        assertThatThrownBy(() -> ExecutionGraphTestUtils.getExecutionJobVertex(jobVertex))
+                .isInstanceOf(JobException.class)
+                .hasMessageContaining("higher than the max parallelism");
     }
 
     @Test
-    public void testLazyInitialization() throws Exception {
+    void testLazyInitialization() throws Exception {
         final int parallelism = 3;
         final int configuredMaxParallelism = 12;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(parallelism, configuredMaxParallelism, -1);
 
-        assertThat(ejv.getParallelism(), is(parallelism));
-        assertThat(ejv.getMaxParallelism(), is(configuredMaxParallelism));
-        assertThat(ejv.isInitialized(), is(false));
+        assertThat(ejv.getParallelism()).isEqualTo(parallelism);
+        assertThat(ejv.getMaxParallelism()).isEqualTo(configuredMaxParallelism);
+        assertThat(ejv.isInitialized()).isFalse();
 
-        assertThat(ejv.getTaskVertices().length, is(0));
+        assertThat(ejv.getTaskVertices()).isEmpty();
 
-        try {
-            ejv.getInputs();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getInputs).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.getProducedDataSets();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getProducedDataSets).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.getSplitAssigner();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getSplitAssigner).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.getOperatorCoordinators();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::getOperatorCoordinators).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.connectToPredecessors(Collections.emptyMap());
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(() -> ejv.connectToPredecessors(Collections.emptyMap()))
+                .isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.executionVertexFinished();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::executionVertexFinished).isInstanceOf(IllegalStateException.class);
 
-        try {
-            ejv.executionVertexUnFinished();
-            Assert.fail("failure is expected");
-        } catch (IllegalStateException e) {
-            // ignore
-        }
+        assertThatThrownBy(ejv::executionVertexUnFinished)
+                .isInstanceOf(IllegalStateException.class);
 
         initializeVertex(ejv);
 
-        assertThat(ejv.isInitialized(), is(true));
-        assertThat(ejv.getTaskVertices().length, is(3));
-        assertThat(ejv.getInputs().size(), is(0));
-        assertThat(ejv.getProducedDataSets().length, is(1));
-        assertThat(ejv.getOperatorCoordinators().size(), is(0));
+        assertThat(ejv.isInitialized()).isTrue();
+        assertThat(ejv.getTaskVertices()).hasSize(3);
+        assertThat(ejv.getInputs()).isEmpty();
+        assertThat(ejv.getProducedDataSets()).hasSize(1);
+        assertThat(ejv.getOperatorCoordinators()).isEmpty();
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testErrorIfInitializationWithoutParallelismDecided() throws Exception {
+    @Test
+    void testErrorIfInitializationWithoutParallelismDecided() throws Exception {
         final ExecutionJobVertex ejv = createDynamicExecutionJobVertex();
 
-        initializeVertex(ejv);
+        assertThatThrownBy(() -> initializeVertex(ejv)).isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testSetParallelismLazily() throws Exception {
+    void testSetParallelismLazily() throws Exception {
         final int parallelism = 3;
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(-1, -1, defaultMaxParallelism);
 
-        assertThat(ejv.isParallelismDecided(), is(false));
+        assertThat(ejv.isParallelismDecided()).isFalse();
 
         ejv.setParallelism(parallelism);
 
-        assertThat(ejv.isParallelismDecided(), is(true));
-        assertThat(ejv.getParallelism(), is(parallelism));
+        assertThat(ejv.isParallelismDecided()).isTrue();
+        assertThat(ejv.getParallelism()).isEqualTo(parallelism);
 
         initializeVertex(ejv);
 
-        assertThat(ejv.getTaskVertices().length, is(parallelism));
+        assertThat(ejv.getTaskVertices()).hasSize(parallelism);
     }
 
     @Test
-    public void testConfiguredMaxParallelismIsRespected() throws Exception {
+    void testConfiguredMaxParallelismIsRespected() throws Exception {
         final int configuredMaxParallelism = 12;
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(
                         -1, configuredMaxParallelism, defaultMaxParallelism);
 
-        assertThat(ejv.getMaxParallelism(), is(configuredMaxParallelism));
+        assertThat(ejv.getMaxParallelism()).isEqualTo(configuredMaxParallelism);
     }
 
     @Test
-    public void testComputingMaxParallelismFromConfiguredParallelism() throws Exception {
+    void testComputingMaxParallelismFromConfiguredParallelism() throws Exception {
         final int parallelism = 300;
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(parallelism, -1, defaultMaxParallelism);
 
-        assertThat(ejv.getMaxParallelism(), is(512));
+        assertThat(ejv.getMaxParallelism()).isEqualTo(512);
     }
 
     @Test
-    public void testFallingBackToDefaultMaxParallelism() throws Exception {
+    void testFallingBackToDefaultMaxParallelism() throws Exception {
         final int defaultMaxParallelism = 13;
         final ExecutionJobVertex ejv =
                 createDynamicExecutionJobVertex(-1, -1, defaultMaxParallelism);
 
-        assertThat(ejv.getMaxParallelism(), is(defaultMaxParallelism));
+        assertThat(ejv.getMaxParallelism()).isEqualTo(defaultMaxParallelism);
     }
 
     static void initializeVertex(ExecutionJobVertex vertex) throws Exception {
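
The same skeleton recurs in every test class touched by this commit: the JUnit4 @ClassRule of type TestExecutorResource becomes a TestExecutorExtension registered via JUnit5's @RegisterExtension, and classes and test methods drop the public modifier since Jupiter does not require it. A minimal sketch of that skeleton, using a hypothetical class name and test body:

    import org.apache.flink.testutils.TestingUtils;
    import org.apache.flink.testutils.executor.TestExecutorExtension;

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.extension.RegisterExtension;

    import java.util.concurrent.ScheduledExecutorService;

    import static org.assertj.core.api.Assertions.assertThat;

    /** Hypothetical test class illustrating the JUnit5 skeleton used above. */
    class MigrationSkeletonTest {

        @RegisterExtension
        static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
                TestingUtils.defaultExecutorExtension();

        @Test
        void sharedExecutorIsAvailable() {
            // The executor is obtained the same way as with the old rule.
            assertThat(EXECUTOR_RESOURCE.getExecutor()).isNotNull();
        }
    }
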
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 456de025668..5ff5d9f201c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -34,11 +34,10 @@ import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.Iterator;
@@ -46,42 +45,37 @@ import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
-import static org.hamcrest.CoreMatchers.is;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.equalTo;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link IntermediateResultPartition}. */
-public class IntermediateResultPartitionTest extends TestLogger {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+public class IntermediateResultPartitionTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     @Test
-    public void testPipelinedPartitionConsumable() throws Exception {
+    void testPipelinedPartitionConsumable() throws Exception {
         IntermediateResult result = createResult(ResultPartitionType.PIPELINED, 2);
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
 
         // Not consumable on init
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
 
         // Partition 1 consumable after data are produced
         partition1.markDataProduced();
-        assertTrue(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isFalse();
 
         // Not consumable if failover happens
         result.resetForNewExecution();
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
     }
 
     @Test
-    public void testBlockingPartitionConsumable() throws Exception {
+    void testBlockingPartitionConsumable() throws Exception {
         IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
@@ -90,31 +84,31 @@ public class IntermediateResultPartitionTest extends TestLogger {
                 partition1.getConsumedPartitionGroups().get(0);
 
         // Not consumable on init
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Not consumable if only one partition is FINISHED
         partition1.markFinished();
-        assertTrue(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Consumable after all partitions are FINISHED
         partition2.markFinished();
-        assertTrue(partition1.isConsumable());
-        assertTrue(partition2.isConsumable());
-        assertTrue(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isTrue();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isTrue();
 
         // Not consumable if failover happens
         result.resetForNewExecution();
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
     }
 
     @Test
-    public void testBlockingPartitionResetting() throws Exception {
+    void testBlockingPartitionResetting() throws Exception {
         IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
         IntermediateResultPartition partition1 = result.getPartitions()[0];
         IntermediateResultPartition partition2 = result.getPartitions()[1];
@@ -123,71 +117,71 @@ public class IntermediateResultPartitionTest extends TestLogger {
                 partition1.getConsumedPartitionGroups().get(0);
 
         // Not consumable on init
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
 
         // Not consumable if partition1 is FINISHED
         partition1.markFinished();
-        assertEquals(1, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertTrue(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(1);
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Reset the result and mark partition2 FINISHED, the result should still not be consumable
         result.resetForNewExecution();
-        assertEquals(2, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(2);
         partition2.markFinished();
-        assertEquals(1, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertFalse(partition1.isConsumable());
-        assertTrue(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(1);
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isTrue();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
 
         // Consumable after all partitions are FINISHED
         partition1.markFinished();
-        assertEquals(0, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertTrue(partition1.isConsumable());
-        assertTrue(partition2.isConsumable());
-        assertTrue(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(0);
+        assertThat(partition1.isConsumable()).isTrue();
+        assertThat(partition2.isConsumable()).isTrue();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isTrue();
 
         // Not consumable again if failover happens
         result.resetForNewExecution();
-        assertEquals(2, consumedPartitionGroup.getNumberOfUnfinishedPartitions());
-        assertFalse(partition1.isConsumable());
-        assertFalse(partition2.isConsumable());
-        assertFalse(consumedPartitionGroup.areAllPartitionsFinished());
+        assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(2);
+        assertThat(partition1.isConsumable()).isFalse();
+        assertThat(partition2.isConsumable()).isFalse();
+        assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
+    void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, false, Arrays.asList(7, 7));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsForNonDynamicPointwiseGraph() throws Exception {
+    void testGetNumberOfSubpartitionsForNonDynamicPointwiseGraph() throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.POINTWISE, false, Arrays.asList(4, 3));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicAllToAllGraph()
+    void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicAllToAllGraph()
             throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, true, Arrays.asList(7, 7));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicPointwiseGraph()
+    void testGetNumberOfSubpartitionsFromConsumerParallelismForDynamicPointwiseGraph()
             throws Exception {
         testGetNumberOfSubpartitions(7, DistributionPattern.POINTWISE, true, Arrays.asList(4, 4));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicAllToAllGraph()
+    void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicAllToAllGraph()
             throws Exception {
         testGetNumberOfSubpartitions(
                 -1, DistributionPattern.ALL_TO_ALL, true, Arrays.asList(13, 13));
     }
 
     @Test
-    public void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicPointwiseGraph()
+    void testGetNumberOfSubpartitionsFromConsumerMaxParallelismForDynamicPointwiseGraph()
             throws Exception {
         testGetNumberOfSubpartitions(-1, DistributionPattern.POINTWISE, true, Arrays.asList(7, 7));
     }
@@ -221,12 +215,12 @@ public class IntermediateResultPartitionTest extends TestLogger {
 
         final IntermediateResult result = producer.getProducedDataSets()[0];
 
-        assertThat(expectedNumSubpartitions.size(), is(producerParallelism));
+        assertThat(expectedNumSubpartitions).hasSize(producerParallelism);
         assertThat(
-                Arrays.stream(result.getPartitions())
-                        .map(IntermediateResultPartition::getNumberOfSubpartitions)
-                        .collect(Collectors.toList()),
-                equalTo(expectedNumSubpartitions));
+                        Arrays.stream(result.getPartitions())
+                                .map(IntermediateResultPartition::getNumberOfSubpartitions)
+                                .collect(Collectors.toList()))
+                .isEqualTo(expectedNumSubpartitions);
     }
 
     public static ExecutionGraph createExecutionGraph(
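
One detail in these assertion rewrites deserves care: JUnit4's assertEquals takes the expected value first, while AssertJ puts the actual value inside assertThat() and the expected value in the chained matcher, so the two arguments swap position during the conversion (compare assertEquals(1, consumedPartitionGroup.getNumberOfUnfinishedPartitions()) with assertThat(consumedPartitionGroup.getNumberOfUnfinishedPartitions()).isEqualTo(1) above). A short illustrative snippet with made-up data:

    import java.util.Arrays;
    import java.util.List;

    import static org.assertj.core.api.Assertions.assertThat;

    class AssertionMappingExample {
        static void demo() {
            List<Integer> unfinishedPartitions = Arrays.asList(1, 2); // illustrative data

            // assertEquals(2, unfinishedPartitions.size()) -> the actual value moves into assertThat(...)
            assertThat(unfinishedPartitions.size()).isEqualTo(2);

            // sizes and emptiness read better through the dedicated collection assertions
            assertThat(unfinishedPartitions).hasSize(2);
            assertThat(unfinishedPartitions.isEmpty()).isFalse();
        }
    }
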
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
index 73dd9cdeb3a..7c00da22a85 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
@@ -43,13 +43,12 @@ import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
 import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.After;
-import org.junit.Before;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -62,28 +61,27 @@ import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
 
 import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /**
  * Tests for removing cached {@link ShuffleDescriptor}s when the related partitions are no longer
  * valid. Currently, there are two scenarios as illustrated in {@link
  * IntermediateResult#clearCachedInformationForPartitionGroup}.
  */
-public class RemoveCachedShuffleDescriptorTest extends TestLogger {
+class RemoveCachedShuffleDescriptorTest {
 
     private static final int PARALLELISM = 4;
 
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private ScheduledExecutorService scheduledExecutorService;
     private ComponentMainThreadExecutor mainThreadExecutor;
     private ManuallyTriggeredScheduledExecutorService ioExecutor;
 
-    @Before
-    public void setup() {
+    @BeforeEach
+    void setup() {
         scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
         mainThreadExecutor =
                 ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor(
@@ -91,21 +89,21 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         ioExecutor = new ManuallyTriggeredScheduledExecutorService();
     }
 
-    @After
-    public void teardown() {
+    @AfterEach
+    void teardown() {
         if (scheduledExecutorService != null) {
             scheduledExecutorService.shutdownNow();
         }
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
+    void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
         // Here we expect no offloaded BLOB.
         testRemoveCacheForAllToAllEdgeAfterFinished(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
+    void testRemoveOffloadedCacheForAllToAllEdgeAfterFinished() throws Exception {
         // Here we expect 4 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the ALL-TO-ALL
         // edge (1).
@@ -129,8 +127,8 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(PARALLELISM, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(PARALLELISM);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         // For the all-to-all edge, we transition all downstream tasks to finished
         CompletableFuture.runAsync(
@@ -140,17 +138,17 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         ioExecutor.triggerAll();
 
         // Cache should be removed since partitions are released
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2));
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2)).isNull();
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
+    void testRemoveNonOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
         testRemoveCacheForAllToAllEdgeAfterFailover(new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
+    void testRemoveOffloadedCacheForAllToAllEdgeAfterFailover() throws Exception {
         // Here we expect 4 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the ALL-TO-ALL
         // edge (1).
@@ -174,25 +172,25 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(PARALLELISM, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(PARALLELISM);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         triggerGlobalFailoverAndComplete(scheduler, v1);
         ioExecutor.triggerAll();
 
         // Cache should be removed during ExecutionVertex#resetForNewExecution
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2));
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2)).isNull();
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
+    void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
         testRemoveCacheForPointwiseEdgeAfterFinished(
                 new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
+    void testRemoveOffloadedCacheForPointwiseEdgeAfterFinished() throws Exception {
         // Here we expect 7 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the POINTWISE
         // edges (4).
@@ -216,8 +214,8 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(1, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(1);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         // For the pointwise edge, we just transition the first downstream task to FINISHED
         ExecutionVertex ev21 =
@@ -229,7 +227,7 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         ioExecutor.triggerAll();
 
         // The cache of the first upstream task should be removed since its partition is released
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0));
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0)).isNull();
 
         // The cache of the other upstream tasks should stay
         final ShuffleDescriptor[] shuffleDescriptorsForOtherVertex =
@@ -237,19 +235,19 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
                         getConsumedCachedShuffleDescriptor(executionGraph, v2, 1),
                         jobId,
                         blobWriter);
-        assertEquals(1, shuffleDescriptorsForOtherVertex.length);
+        assertThat(shuffleDescriptorsForOtherVertex).hasSize(1);
 
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     @Test
-    public void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
+    void testRemoveNonOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
         testRemoveCacheForPointwiseEdgeAfterFailover(
                 new TestingBlobWriter(Integer.MAX_VALUE), 0, 0);
     }
 
     @Test
-    public void testRemoveOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
+    void testRemoveOffloadedCacheForPointwiseEdgeAfterFailover() throws Exception {
         // Here we expect 7 offloaded BLOBs:
         // JobInformation (1) + TaskInformation (2) + Cache of ShuffleDescriptors for the POINTWISE
         // edges (4).
@@ -273,15 +271,15 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
         final ShuffleDescriptor[] shuffleDescriptors =
                 deserializeShuffleDescriptors(
                         getConsumedCachedShuffleDescriptor(executionGraph, v2), jobId, blobWriter);
-        assertEquals(1, shuffleDescriptors.length);
-        assertEquals(expectedBefore, blobWriter.numberOfBlobs());
+        assertThat(shuffleDescriptors).hasSize(1);
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedBefore);
 
         triggerExceptionAndComplete(executionGraph, v1, v2);
         ioExecutor.triggerAll();
 
         // The cache of the first upstream task should be removed during
         // ExecutionVertex#resetForNewExecution
-        assertNull(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0));
+        assertThat(getConsumedCachedShuffleDescriptor(executionGraph, v2, 0)).isNull();
 
         // The cache of the other upstream tasks should stay
         final ShuffleDescriptor[] shuffleDescriptorsForOtherVertex =
@@ -289,9 +287,9 @@ public class RemoveCachedShuffleDescriptorTest extends TestLogger {
                         getConsumedCachedShuffleDescriptor(executionGraph, v2, 1),
                         jobId,
                         blobWriter);
-        assertEquals(1, shuffleDescriptorsForOtherVertex.length);
+        assertThat(shuffleDescriptorsForOtherVertex).hasSize(1);
 
-        assertEquals(expectedAfter, blobWriter.numberOfBlobs());
+        assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
     }
 
     private DefaultScheduler createSchedulerAndDeploy(
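
Besides the assertion changes, this file also migrates the lifecycle hooks: JUnit4's @Before/@After become JUnit5's @BeforeEach/@AfterEach, and the TestLogger superclass is dropped along with the rest of the JUnit4 machinery. A minimal sketch of the resulting per-test lifecycle, with a hypothetical class name:

    import org.junit.jupiter.api.AfterEach;
    import org.junit.jupiter.api.BeforeEach;
    import org.junit.jupiter.api.Test;

    import java.util.concurrent.Executors;
    import java.util.concurrent.ScheduledExecutorService;

    /** Hypothetical class showing the JUnit5 lifecycle used above. */
    class LifecycleSketchTest {

        private ScheduledExecutorService scheduledExecutorService;

        @BeforeEach
        void setup() {
            // runs before every test method
            scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
        }

        @AfterEach
        void teardown() {
            // runs after every test method, even if it failed
            if (scheduledExecutorService != null) {
                scheduledExecutorService.shutdownNow();
            }
        }

        @Test
        void executorIsFreshPerTest() {
            scheduledExecutorService.execute(() -> {});
        }
    }
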
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
index 47cc722fa9b..66e097b34b1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
@@ -29,117 +29,89 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.operators.util.TaskConfig;
 import org.apache.flink.util.InstantiationUtil;
 
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
 
 import java.io.IOException;
 import java.net.URL;
 import java.net.URLClassLoader;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 @SuppressWarnings("serial")
-public class JobTaskVertexTest {
+class JobTaskVertexTest {
 
     @Test
-    public void testConnectDirectly() {
+    void testConnectDirectly() {
         JobVertex source = new JobVertex("source");
         JobVertex target = new JobVertex("target");
         target.connectNewDataSetAsInput(
                 source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
 
-        assertTrue(source.isInputVertex());
-        assertFalse(source.isOutputVertex());
-        assertFalse(target.isInputVertex());
-        assertTrue(target.isOutputVertex());
+        assertThat(source.isInputVertex()).isTrue();
+        assertThat(source.isOutputVertex()).isFalse();
+        assertThat(target.isInputVertex()).isFalse();
+        assertThat(target.isOutputVertex()).isTrue();
 
-        assertEquals(1, source.getNumberOfProducedIntermediateDataSets());
-        assertEquals(1, target.getNumberOfInputs());
+        assertThat(source.getNumberOfProducedIntermediateDataSets()).isEqualTo(1);
+        assertThat(target.getNumberOfInputs()).isEqualTo(1);
 
-        assertEquals(target.getInputs().get(0).getSource(), source.getProducedDataSets().get(0));
+        assertThat(source.getProducedDataSets().get(0))
+                .isEqualTo(target.getInputs().get(0).getSource());
 
-        assertEquals(target, source.getProducedDataSets().get(0).getConsumer().getTarget());
+        assertThat(source.getProducedDataSets().get(0).getConsumer().getTarget()).isEqualTo(target);
     }
 
     @Test
-    public void testOutputFormat() {
-        try {
-            final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
-
-            OperatorID operatorID = new OperatorID();
-            Configuration parameters = new Configuration();
-            parameters.setString("test_key", "test_value");
-            new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
-                    .addOutputFormat(operatorID, new TestingOutputFormat(parameters))
-                    .addParameters(operatorID, parameters)
-                    .write(new TaskConfig(vertex.getConfiguration()));
-
-            final ClassLoader cl = new TestClassLoader();
-
-            try {
-                vertex.initializeOnMaster(cl);
-                fail("Did not throw expected exception.");
-            } catch (TestException e) {
-                // all good
-            }
+    void testOutputFormat() throws Exception {
+        final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
 
-            InputOutputFormatVertex copy = InstantiationUtil.clone(vertex);
-            ClassLoader ctxCl = Thread.currentThread().getContextClassLoader();
-            try {
-                copy.initializeOnMaster(cl);
-                fail("Did not throw expected exception.");
-            } catch (TestException e) {
-                // all good
-            }
-            assertEquals(
-                    "Previous classloader was not restored.",
-                    ctxCl,
-                    Thread.currentThread().getContextClassLoader());
-
-            try {
-                copy.finalizeOnMaster(cl);
-                fail("Did not throw expected exception.");
-            } catch (TestException e) {
-                // all good
-            }
-            assertEquals(
-                    "Previous classloader was not restored.",
-                    ctxCl,
-                    Thread.currentThread().getContextClassLoader());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
+        OperatorID operatorID = new OperatorID();
+        Configuration parameters = new Configuration();
+        parameters.setString("test_key", "test_value");
+        new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
+                .addOutputFormat(operatorID, new TestingOutputFormat(parameters))
+                .addParameters(operatorID, parameters)
+                .write(new TaskConfig(vertex.getConfiguration()));
+
+        final ClassLoader cl = new TestClassLoader();
+
+        assertThatThrownBy(() -> vertex.initializeOnMaster(cl)).isInstanceOf(TestException.class);
+
+        InputOutputFormatVertex copy = InstantiationUtil.clone(vertex);
+        ClassLoader ctxCl = Thread.currentThread().getContextClassLoader();
+        assertThatThrownBy(() -> copy.initializeOnMaster(cl)).isInstanceOf(TestException.class);
+
+        assertThat(Thread.currentThread().getContextClassLoader())
+                .as("Previous classloader was not restored.")
+                .isEqualTo(ctxCl);
+
+        assertThatThrownBy(() -> copy.finalizeOnMaster(cl)).isInstanceOf(TestException.class);
+        assertThat(Thread.currentThread().getContextClassLoader())
+                .as("Previous classloader was not restored.")
+                .isEqualTo(ctxCl);
     }
 
     @Test
-    public void testInputFormat() {
-        try {
-            final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
-
-            OperatorID operatorID = new OperatorID();
-            Configuration parameters = new Configuration();
-            parameters.setString("test_key", "test_value");
-            new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
-                    .addInputFormat(operatorID, new TestInputFormat(parameters))
-                    .addParameters(operatorID, "test_key", "test_value")
-                    .write(new TaskConfig(vertex.getConfiguration()));
-
-            final ClassLoader cl = new TestClassLoader();
-
-            vertex.initializeOnMaster(cl);
-            InputSplit[] splits = vertex.getInputSplitSource().createInputSplits(77);
-
-            assertNotNull(splits);
-            assertEquals(1, splits.length);
-            assertEquals(TestSplit.class, splits[0].getClass());
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail(e.getMessage());
-        }
+    void testInputFormat() throws Exception {
+        final InputOutputFormatVertex vertex = new InputOutputFormatVertex("Name");
+
+        OperatorID operatorID = new OperatorID();
+        Configuration parameters = new Configuration();
+        parameters.setString("test_key", "test_value");
+        new InputOutputFormatContainer(Thread.currentThread().getContextClassLoader())
+                .addInputFormat(operatorID, new TestInputFormat(parameters))
+                .addParameters(operatorID, "test_key", "test_value")
+                .write(new TaskConfig(vertex.getConfiguration()));
+
+        final ClassLoader cl = new TestClassLoader();
+
+        vertex.initializeOnMaster(cl);
+        InputSplit[] splits = vertex.getInputSplitSource().createInputSplits(77);
+
+        assertThat(splits).isNotNull();
+        assertThat(splits).hasSize(1);
+        assertThat(splits[0].getClass()).isEqualTo(TestSplit.class);
     }
 
     // --------------------------------------------------------------------------------------------
@@ -191,8 +163,8 @@ public class JobTaskVertexTest {
                 throw new IllegalStateException("Context ClassLoader was not correctly switched.");
             }
             for (String key : expectedParameters.keySet()) {
-                assertEquals(
-                        expectedParameters.getString(key, null), parameters.getString(key, null));
+                assertThat(parameters.getString(key, null))
+                        .isEqualTo(expectedParameters.getString(key, null));
             }
             isConfigured = true;
         }
@@ -244,8 +216,8 @@ public class JobTaskVertexTest {
                 throw new IllegalStateException("Context ClassLoader was not correctly switched.");
             }
             for (String key : expectedParameters.keySet()) {
-                assertEquals(
-                        expectedParameters.getString(key, null), parameters.getString(key, null));
+                assertThat(parameters.getString(key, null))
+                        .isEqualTo(expectedParameters.getString(key, null));
             }
             isConfigured = true;
         }
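
The rewrite of testOutputFormat and testInputFormat above follows another recurring pattern of this migration: the outer try { ... } catch (Exception e) { fail(e.getMessage()); } wrapper disappears in favour of declaring throws Exception on the test method, and the inner expected-failure blocks collapse into assertThatThrownBy. A condensed sketch of both halves, using a hypothetical helper:

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThatThrownBy;

    class ExceptionHandlingSketchTest {

        @Test
        void happyPath() throws Exception {
            // An unexpected exception simply propagates and fails the test;
            // no catch-and-fail wrapper is needed.
            mayThrow(false);
        }

        @Test
        void failurePath() {
            // The expected failure is asserted on the exact call that should throw.
            assertThatThrownBy(() -> mayThrow(true)).isInstanceOf(IllegalStateException.class);
        }

        private static void mayThrow(boolean fail) throws Exception {
            if (fail) {
                throw new IllegalStateException("illustrative failure");
            }
        }
    }
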
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
index 109484fc1a1..4e1d9cb76a4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
@@ -37,16 +37,14 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingPipelinedRegion;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 import org.apache.flink.util.IterableUtils;
-import org.apache.flink.util.TestLogger;
 
-import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
 import org.apache.flink.shaded.guava30.com.google.common.collect.Sets;
 
-import org.junit.Before;
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -58,31 +56,26 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 
-import static junit.framework.TestCase.assertSame;
-import static junit.framework.TestCase.assertTrue;
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createExecutionGraph;
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createNoOpVertex;
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.PIPELINED;
 import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.fail;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /** Unit tests for {@link DefaultExecutionTopology}. */
-public class DefaultExecutionTopologyTest extends TestLogger {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class DefaultExecutionTopologyTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private DefaultExecutionGraph executionGraph;
 
     private DefaultExecutionTopology adapter;
 
-    @Before
-    public void setUp() throws Exception {
+    @BeforeEach
+    void setUp() throws Exception {
         JobVertex[] jobVertices = new JobVertex[2];
         int parallelism = 3;
         jobVertices[0] = createNoOpVertex(parallelism);
@@ -93,13 +86,13 @@ public class DefaultExecutionTopologyTest extends TestLogger {
     }
 
     @Test
-    public void testConstructor() {
+    void testConstructor() {
         // implicitly tests order constraint of getVertices()
         assertGraphEquals(executionGraph, adapter);
     }
 
     @Test
-    public void testGetResultPartition() {
+    void testGetResultPartition() {
         for (ExecutionVertex vertex : executionGraph.getAllExecutionVertices()) {
             for (Map.Entry<IntermediateResultPartitionID, IntermediateResultPartition> entry :
                     vertex.getProducedPartitions().entrySet()) {
@@ -113,7 +106,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
     }
 
     @Test
-    public void testResultPartitionStateSupplier() {
+    void testResultPartitionStateSupplier() {
         final IntermediateResultPartition intermediateResultPartition =
                 IterableUtils.toStream(executionGraph.getAllExecutionVertices())
                         .flatMap(v -> v.getProducedPartitions().values().stream())
@@ -123,41 +116,33 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         final DefaultResultPartition schedulingResultPartition =
                 adapter.getResultPartition(intermediateResultPartition.getPartitionId());
 
-        assertEquals(ResultPartitionState.CREATED, schedulingResultPartition.getState());
+        assertThat(schedulingResultPartition.getState()).isEqualTo(ResultPartitionState.CREATED);
 
         intermediateResultPartition.markDataProduced();
-        assertEquals(ResultPartitionState.CONSUMABLE, schedulingResultPartition.getState());
+        assertThat(schedulingResultPartition.getState()).isEqualTo(ResultPartitionState.CONSUMABLE);
     }
 
     @Test
-    public void testGetVertexOrThrow() {
-        try {
-            adapter.getVertex(new ExecutionVertexID(new JobVertexID(), 0));
-            fail("get not exist vertex");
-        } catch (IllegalArgumentException exception) {
-            // expected
-        }
+    void testGetVertexOrThrow() {
+        assertThatThrownBy(() -> adapter.getVertex(new ExecutionVertexID(new JobVertexID(), 0)))
+                .isInstanceOf(IllegalArgumentException.class);
     }
 
     @Test
-    public void testResultPartitionOrThrow() {
-        try {
-            adapter.getResultPartition(new IntermediateResultPartitionID());
-            fail("get not exist result partition");
-        } catch (IllegalArgumentException exception) {
-            // expected
-        }
+    void testResultPartitionOrThrow() {
+        assertThatThrownBy(() -> adapter.getResultPartition(new IntermediateResultPartitionID()))
+                .isInstanceOf(IllegalArgumentException.class);
     }
 
     @Test
-    public void testGetAllPipelinedRegions() {
+    void testGetAllPipelinedRegions() {
         final Iterable<DefaultSchedulingPipelinedRegion> allPipelinedRegions =
                 adapter.getAllPipelinedRegions();
-        assertEquals(1, Iterables.size(allPipelinedRegions));
+        assertThat(allPipelinedRegions).hasSize(1);
     }
 
     @Test
-    public void testGetPipelinedRegionOfVertex() {
+    void testGetPipelinedRegionOfVertex() {
         for (DefaultExecutionVertex vertex : adapter.getVertices()) {
             final DefaultSchedulingPipelinedRegion pipelinedRegion =
                     adapter.getPipelinedRegionOfVertex(vertex.getId());
@@ -165,8 +150,8 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         }
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testErrorIfCoLocatedTasksAreNotInSameRegion() throws Exception {
+    @Test
+    void testErrorIfCoLocatedTasksAreNotInSameRegion() throws Exception {
         int parallelism = 3;
         final JobVertex v1 = createNoOpVertex(parallelism);
         final JobVertex v2 = createNoOpVertex(parallelism);
@@ -176,13 +161,12 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         v2.setSlotSharingGroup(slotSharingGroup);
         v1.setStrictlyCoLocatedWith(v2);
 
-        final DefaultExecutionGraph executionGraph =
-                createExecutionGraph(EXECUTOR_RESOURCE.getExecutor(), v1, v2);
-        DefaultExecutionTopology.fromExecutionGraph(executionGraph);
+        assertThatThrownBy(() -> createExecutionGraph(EXECUTOR_RESOURCE.getExecutor(), v1, v2))
+                .isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testUpdateTopology() throws Exception {
+    void testUpdateTopology() throws Exception {
         final JobVertex[] jobVertices = createJobVertices(BLOCKING);
         executionGraph = createDynamicGraph(jobVertices);
         adapter = DefaultExecutionTopology.fromExecutionGraph(executionGraph);
@@ -192,18 +176,17 @@ public class DefaultExecutionTopologyTest extends TestLogger {
 
         executionGraph.initializeJobVertex(ejv1, 0L);
         adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv1));
-        assertThat(IterableUtils.toStream(adapter.getVertices()).count(), is(3L));
+        assertThat(adapter.getVertices()).hasSize(3);
 
         executionGraph.initializeJobVertex(ejv2, 0L);
         adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv2));
-        assertThat(IterableUtils.toStream(adapter.getVertices()).count(), is(6L));
+        assertThat(adapter.getVertices()).hasSize(6);
 
         assertGraphEquals(executionGraph, adapter);
     }
 
-    @Test(expected = IllegalStateException.class)
-    public void testErrorIfUpdateTopologyWithNewVertexPipelinedConnectedToOldOnes()
-            throws Exception {
+    @Test
+    void testErrorIfUpdateTopologyWithNewVertexPipelinedConnectedToOldOnes() throws Exception {
         final JobVertex[] jobVertices = createJobVertices(PIPELINED);
         executionGraph = createDynamicGraph(jobVertices);
         adapter = DefaultExecutionTopology.fromExecutionGraph(executionGraph);
@@ -215,11 +198,15 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv1));
 
         executionGraph.initializeJobVertex(ejv2, 0L);
-        adapter.notifyExecutionGraphUpdated(executionGraph, Collections.singletonList(ejv2));
+        assertThatThrownBy(
+                        () ->
+                                adapter.notifyExecutionGraphUpdated(
+                                        executionGraph, Collections.singletonList(ejv2)))
+                .isInstanceOf(IllegalStateException.class);
     }
 
     @Test
-    public void testExistingRegionsAreNotAffectedDuringTopologyUpdate() throws Exception {
+    void testExistingRegionsAreNotAffectedDuringTopologyUpdate() throws Exception {
         final JobVertex[] jobVertices = createJobVertices(BLOCKING);
         executionGraph = createDynamicGraph(jobVertices);
         adapter = DefaultExecutionTopology.fromExecutionGraph(executionGraph);
@@ -237,7 +224,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
         SchedulingPipelinedRegion regionNew =
                 adapter.getPipelinedRegionOfVertex(new ExecutionVertexID(ejv1.getJobVertexId(), 0));
 
-        assertSame(regionOld, regionNew);
+        assertThat(regionNew).isSameAs(regionOld);
     }
 
     private JobVertex[] createJobVertices(ResultPartitionType resultPartitionType) {
@@ -260,7 +247,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             final DefaultSchedulingPipelinedRegion pipelinedRegionOfVertex) {
         final Set<DefaultExecutionVertex> allVertices =
                 Sets.newHashSet(pipelinedRegionOfVertex.getVertices());
-        assertEquals(Sets.newHashSet(adapter.getVertices()), allVertices);
+        assertThat(allVertices).isEqualTo(Sets.newHashSet(adapter.getVertices()));
     }
 
     private static void assertGraphEquals(
@@ -274,7 +261,7 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             ExecutionVertex originalVertex = originalVertices.next();
             DefaultExecutionVertex adaptedVertex = adaptedVertices.next();
 
-            assertEquals(originalVertex.getID(), adaptedVertex.getId());
+            assertThat(adaptedVertex.getId()).isEqualTo(originalVertex.getID());
 
             List<IntermediateResultPartition> originalConsumedPartitions = new ArrayList<>();
             for (ConsumedPartitionGroup consumedPartitionGroup :
@@ -300,17 +287,16 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             assertPartitionsEquals(originalProducedPartitions, adaptedProducedPartitions);
         }
 
-        assertFalse(
-                "Number of adapted vertices exceeds number of original vertices.",
-                adaptedVertices.hasNext());
+        assertThat(adaptedVertices)
+                .as("Number of adapted vertices exceeds number of original vertices.")
+                .isExhausted();
     }
 
     private static void assertPartitionsEquals(
             Iterable<IntermediateResultPartition> originalResultPartitions,
             Iterable<DefaultResultPartition> adaptedResultPartitions) {
 
-        assertEquals(
-                Iterables.size(originalResultPartitions), Iterables.size(adaptedResultPartitions));
+        assertThat(originalResultPartitions).hasSameSizeAs(adaptedResultPartitions);
 
         for (IntermediateResultPartition originalPartition : originalResultPartitions) {
             DefaultResultPartition adaptedPartition =
@@ -331,16 +317,14 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             ConsumerVertexGroup consumerVertexGroup = originalPartition.getConsumerVertexGroup();
             Optional<ConsumerVertexGroup> adaptedConsumers =
                     adaptedPartition.getConsumerVertexGroup();
-            assertTrue(adaptedConsumers.isPresent());
+            assertThat(adaptedConsumers).isPresent();
             for (ExecutionVertexID originalId : consumerVertexGroup) {
                 // it is sufficient to verify that some vertex exists with the correct ID here,
                 // since deep equality is verified later in the main loop
                 // this DOES rely on an implicit assumption that the vertices objects returned by
                 // the topology are
                 // identical to those stored in the partition
-                assertTrue(
-                        IterableUtils.toStream(adaptedConsumers.get())
-                                .anyMatch(adaptedConsumer -> adaptedConsumer.equals(originalId)));
+                assertThat(adaptedConsumers.get()).contains(originalId);
             }
         }
     }
@@ -349,11 +333,11 @@ public class DefaultExecutionTopologyTest extends TestLogger {
             IntermediateResultPartition originalPartition,
             DefaultResultPartition adaptedPartition) {
 
-        assertEquals(originalPartition.getPartitionId(), adaptedPartition.getId());
-        assertEquals(
-                originalPartition.getIntermediateResult().getId(), adaptedPartition.getResultId());
-        assertEquals(originalPartition.getResultType(), adaptedPartition.getResultType());
-        assertEquals(
-                originalPartition.getProducer().getID(), adaptedPartition.getProducer().getId());
+        assertThat(adaptedPartition.getId()).isEqualTo(originalPartition.getPartitionId());
+        assertThat(adaptedPartition.getResultId())
+                .isEqualTo(originalPartition.getIntermediateResult().getId());
+        assertThat(adaptedPartition.getResultType()).isEqualTo(originalPartition.getResultType());
+        assertThat(adaptedPartition.getProducer().getId())
+                .isEqualTo(originalPartition.getProducer().getID());
     }
 }
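
Worth noting in the DefaultExecutionTopologyTest changes: the old @Test(expected = IllegalStateException.class) form passed as long as the exception surfaced anywhere in the test method, while the new assertThatThrownBy form pins the expectation to one specific call, so an exception thrown earlier during setup now fails the test instead of silently satisfying it. A minimal sketch of that difference, with hypothetical prepare/trigger steps:

    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThatThrownBy;

    class ExpectedExceptionScopeTest {

        // With @Test(expected = IllegalStateException.class), a failure in prepare()
        // would also have satisfied the expectation and hidden a broken setup.
        @Test
        void onlyTheTriggerMayThrow() {
            prepare(); // must succeed; an exception here fails the test

            assertThatThrownBy(ExpectedExceptionScopeTest::trigger)
                    .isInstanceOf(IllegalStateException.class);
        }

        private static void prepare() {
            // hypothetical setup that must not throw
        }

        private static void trigger() {
            throw new IllegalStateException("illustrative failure");
        }
    }
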
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
index f97116e58b2..fc2df697464 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
@@ -27,10 +27,9 @@ import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition;
 import org.apache.flink.util.IterableUtils;
-import org.apache.flink.util.TestLogger;
 
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 import java.util.Collections;
 import java.util.List;
@@ -38,10 +37,10 @@ import java.util.Map;
 import java.util.function.Supplier;
 
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link DefaultExecutionVertex}. */
-public class DefaultExecutionVertexTest extends TestLogger {
+class DefaultExecutionVertexTest {
 
     private final TestExecutionStateSupplier stateSupplier = new TestExecutionStateSupplier();
 
@@ -51,8 +50,8 @@ public class DefaultExecutionVertexTest extends TestLogger {
 
     private IntermediateResultPartitionID intermediateResultPartitionId;
 
-    @Before
-    public void setUp() throws Exception {
+    @BeforeEach
+    void setUp() throws Exception {
 
         intermediateResultPartitionId = new IntermediateResultPartitionID();
 
@@ -97,15 +96,15 @@ public class DefaultExecutionVertexTest extends TestLogger {
     }
 
     @Test
-    public void testGetExecutionState() {
+    void testGetExecutionState() {
         for (ExecutionState state : ExecutionState.values()) {
             stateSupplier.setExecutionState(state);
-            assertEquals(state, producerVertex.getState());
+            assertThat(producerVertex.getState()).isEqualTo(state);
         }
     }
 
     @Test
-    public void testGetProducedResultPartitions() {
+    void testGetProducedResultPartitions() {
         IntermediateResultPartitionID partitionIds1 =
                 IterableUtils.toStream(producerVertex.getProducedResults())
                         .findAny()
@@ -114,11 +113,11 @@ public class DefaultExecutionVertexTest extends TestLogger {
                                 () ->
                                         new IllegalArgumentException(
                                                 "can not find result partition"));
-        assertEquals(partitionIds1, intermediateResultPartitionId);
+        assertThat(intermediateResultPartitionId).isEqualTo(partitionIds1);
     }
 
     @Test
-    public void testGetConsumedResultPartitions() {
+    void testGetConsumedResultPartitions() {
         IntermediateResultPartitionID partitionIds1 =
                 IterableUtils.toStream(consumerVertex.getConsumedResults())
                         .findAny()
@@ -127,7 +126,7 @@ public class DefaultExecutionVertexTest extends TestLogger {
                                 () ->
                                         new IllegalArgumentException(
                                                 "can not find result partition"));
-        assertEquals(partitionIds1, intermediateResultPartitionId);
+        assertThat(intermediateResultPartitionId).isEqualTo(partitionIds1);
     }
 
     /** A simple implementation of {@link Supplier} for testing. */
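
The same mechanical migration is applied to this file and to the test classes below: drop the TestLogger base class, make the test class and its methods package-private, replace the org.junit annotations with their org.junit.jupiter.api counterparts, and switch assertions to AssertJ's assertThat. A minimal before/after sketch of that pattern (the class name and values here are illustrative, not taken from the patch):

    // Before: JUnit 4 + Assert
    //
    // public class ExampleTest extends TestLogger {
    //     @Before
    //     public void setUp() { ... }
    //
    //     @Test
    //     public void testValue() {
    //         assertEquals(3, computedValue);
    //     }
    // }

    // After: JUnit 5 + AssertJ
    import org.junit.jupiter.api.BeforeEach;
    import org.junit.jupiter.api.Test;

    import static org.assertj.core.api.Assertions.assertThat;

    class ExampleTest {
        private int value;

        @BeforeEach
        void setUp() {
            value = 1 + 2;
        }

        @Test
        void testValue() {
            // assertThat(actual).isEqualTo(expected) replaces assertEquals(expected, actual)
            assertThat(value).isEqualTo(3);
        }
    }
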
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
index 424dd469212..6d0024626c1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
@@ -24,24 +24,19 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
-import org.apache.flink.util.TestLogger;
 
-import org.junit.Before;
-import org.junit.Test;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
 
 import java.util.HashMap;
 import java.util.Map;
 import java.util.function.Supplier;
 
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.contains;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link DefaultResultPartition}. */
-public class DefaultResultPartitionTest extends TestLogger {
+class DefaultResultPartitionTest {
 
     private static final TestResultPartitionStateSupplier resultPartitionState =
             new TestResultPartitionStateSupplier();
@@ -55,8 +50,8 @@ public class DefaultResultPartitionTest extends TestLogger {
     private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> consumerVertexGroups =
             new HashMap<>();
 
-    @Before
-    public void setUp() {
+    @BeforeEach
+    void setUp() {
         resultPartition =
                 new DefaultResultPartition(
                         resultPartitionId,
@@ -70,24 +65,24 @@ public class DefaultResultPartitionTest extends TestLogger {
     }
 
     @Test
-    public void testGetPartitionState() {
+    void testGetPartitionState() {
         for (ResultPartitionState state : ResultPartitionState.values()) {
             resultPartitionState.setResultPartitionState(state);
-            assertEquals(state, resultPartition.getState());
+            assertThat(resultPartition.getState()).isEqualTo(state);
         }
     }
 
     @Test
-    public void testGetConsumerVertexGroup() {
+    void testGetConsumerVertexGroup() {
 
-        assertFalse(resultPartition.getConsumerVertexGroup().isPresent());
+        assertThat(resultPartition.getConsumerVertexGroup()).isNotPresent();
 
         // test update consumers
         ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
         consumerVertexGroups.put(
                 resultPartition.getId(), ConsumerVertexGroup.fromSingleVertex(executionVertexId));
-        assertTrue(resultPartition.getConsumerVertexGroup().isPresent());
-        assertThat(resultPartition.getConsumerVertexGroup().get(), contains(executionVertexId));
+        assertThat(resultPartition.getConsumerVertexGroup()).isPresent();
+        assertThat(resultPartition.getConsumerVertexGroup().get()).containsExactly(executionVertexId);
     }
 
     /** A test {@link ResultPartitionState} supplier. */
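
AssertJ ships dedicated assertions for java.util.Optional, which is what replaces the assertFalse/assertTrue plus Hamcrest contains combination in the hunk above. A small, self-contained sketch of the same assertion style (the class name and values are illustrative, not from the patch):

    import java.util.Collections;
    import java.util.List;
    import java.util.Optional;

    import static org.assertj.core.api.Assertions.assertThat;

    class OptionalAssertionSketch {
        public static void main(String[] args) {
            Optional<List<String>> consumers = Optional.empty();
            // replaces assertFalse(consumers.isPresent())
            assertThat(consumers).isNotPresent();

            consumers = Optional.of(Collections.singletonList("vertex-0"));
            // replaces assertTrue(consumers.isPresent())
            assertThat(consumers).isPresent();
            // replaces the Hamcrest contains(...) matcher on the unwrapped value
            assertThat(consumers.get()).containsExactly("vertex-0");
        }
    }
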
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index d6508d8b66a..3c9bed61354 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -40,11 +40,10 @@ import org.apache.flink.runtime.scheduler.SchedulerBase;
 import org.apache.flink.runtime.taskmanager.TaskExecutionState;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.Iterator;
@@ -52,24 +51,23 @@ import java.util.List;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.is;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test for {@link AdaptiveBatchScheduler}. */
-public class AdaptiveBatchSchedulerTest extends TestLogger {
+class AdaptiveBatchSchedulerTest {
 
     private static final int SOURCE_PARALLELISM_1 = 6;
     private static final int SOURCE_PARALLELISM_2 = 4;
 
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     private static final ComponentMainThreadExecutor mainThreadExecutor =
             ComponentMainThreadExecutorServiceAdapter.forMainThread();
 
     @Test
-    public void testAdaptiveBatchScheduler() throws Exception {
+    void testAdaptiveBatchScheduler() throws Exception {
         JobGraph jobGraph = createJobGraph(false);
         Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
         JobVertex source1 = jobVertexIterator.next();
@@ -82,22 +80,22 @@ public class AdaptiveBatchSchedulerTest extends TestLogger {
         final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
 
         scheduler.startScheduling();
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source1 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source2 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(10));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(10);
 
         // check that the jobGraph is updated
-        assertThat(sink.getParallelism(), is(10));
+        assertThat(sink.getParallelism()).isEqualTo(10);
     }
 
     @Test
-    public void testDecideParallelismForForwardTarget() throws Exception {
+    void testDecideParallelismForForwardTarget() throws Exception {
         JobGraph jobGraph = createJobGraph(true);
         Iterator<JobVertex> jobVertexIterator = jobGraph.getVertices().iterator();
         JobVertex source1 = jobVertexIterator.next();
@@ -110,18 +108,18 @@ public class AdaptiveBatchSchedulerTest extends TestLogger {
         final ExecutionJobVertex sinkExecutionJobVertex = graph.getJobVertex(sink.getID());
 
         scheduler.startScheduling();
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source1 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source1);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(-1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(-1);
 
         // trigger source2 finished.
         transitionExecutionsState(scheduler, ExecutionState.FINISHED, source2);
-        assertThat(sinkExecutionJobVertex.getParallelism(), is(SOURCE_PARALLELISM_1));
+        assertThat(sinkExecutionJobVertex.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
 
         // check that the jobGraph is updated
-        assertThat(sink.getParallelism(), is(SOURCE_PARALLELISM_1));
+        assertThat(sink.getParallelism()).isEqualTo(SOURCE_PARALLELISM_1);
     }
 
     /** Transition the state of all executions. */
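
The @ClassRule to @RegisterExtension change follows the standard JUnit 5 extension model: a static field annotated with @RegisterExtension gets class-level lifecycle, so the shared executor is created once and shut down after the last test, just as the old class rule was. A minimal sketch, assuming the Flink test utilities referenced in the diff (TestingUtils.defaultExecutorExtension) and assuming the extension exposes its executor via getExecutor() as the resource it replaces did; the class name is illustrative:

    import org.apache.flink.testutils.TestingUtils;
    import org.apache.flink.testutils.executor.TestExecutorExtension;

    import org.junit.jupiter.api.Test;
    import org.junit.jupiter.api.extension.RegisterExtension;

    import java.util.concurrent.ScheduledExecutorService;

    import static org.assertj.core.api.Assertions.assertThat;

    class ExecutorExtensionSketch {

        // static => class-level lifecycle, equivalent to the former @ClassRule
        @RegisterExtension
        static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
                TestingUtils.defaultExecutorExtension();

        @Test
        void executorIsAvailable() {
            assertThat(EXECUTOR_EXTENSION.getExecutor()).isNotNull();
        }
    }
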
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
index dbf6cb99062..b45ba3f5ebe 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -29,11 +29,10 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.scheduler.adaptivebatch.AdaptiveBatchScheduler;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.apache.flink.testutils.TestingUtils;
-import org.apache.flink.testutils.executor.TestExecutorResource;
-import org.apache.flink.util.TestLogger;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
 
-import org.junit.ClassRule;
-import org.junit.Test;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Arrays;
 import java.util.HashSet;
@@ -41,14 +40,13 @@ import java.util.Set;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.stream.Collectors;
 
-import static org.hamcrest.Matchers.containsInAnyOrder;
-import static org.junit.Assert.assertEquals;
+import static org.assertj.core.api.Assertions.assertThat;
 
 /** Unit tests for {@link ForwardGroupComputeUtil}. */
-public class ForwardGroupComputeUtilTest extends TestLogger {
-    @ClassRule
-    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
-            TestingUtils.defaultExecutorResource();
+class ForwardGroupComputeUtilTest {
+    @RegisterExtension
+    static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorExtension();
 
     /**
      * Tests that the computation of the job graph with isolated vertices works correctly.
@@ -62,7 +60,7 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testIsolatedVertices() throws Exception {
+    void testIsolatedVertices() throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
@@ -83,14 +81,14 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testVariousResultPartitionTypesBetweenVertices() throws Exception {
+    void testVariousResultPartitionTypesBetweenVertices() throws Exception {
         testThreeVerticesConnectSequentially(false, true, 1, 2);
         testThreeVerticesConnectSequentially(false, false, 0);
         testThreeVerticesConnectSequentially(true, true, 1, 3);
     }
 
     private void testThreeVerticesConnectSequentially(
-            boolean isForward1, boolean isForward2, int numOfGroups, int... groupSizes)
+            boolean isForward1, boolean isForward2, int numOfGroups, Integer... groupSizes)
             throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
@@ -129,7 +127,7 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testTwoInputsMergesIntoOne() throws Exception {
+    void testTwoInputsMergesIntoOne() throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
@@ -164,7 +162,7 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
      * </pre>
      */
     @Test
-    public void testOneInputSplitsIntoTwo() throws Exception {
+    void testOneInputSplitsIntoTwo() throws Exception {
         JobVertex v1 = new JobVertex("v1");
         JobVertex v2 = new JobVertex("v2");
         JobVertex v3 = new JobVertex("v3");
@@ -193,10 +191,11 @@ public class ForwardGroupComputeUtilTest extends TestLogger {
                         .values());
     }
 
-    private static void checkGroupSize(Set<ForwardGroup> groups, int numOfGroups, int... sizes) {
-        assertEquals(numOfGroups, groups.size());
-        containsInAnyOrder(
-                groups.stream().map(ForwardGroup::size).collect(Collectors.toList()), sizes);
+    private static void checkGroupSize(
+            Set<ForwardGroup> groups, int numOfGroups, Integer... sizes) {
+        assertThat(groups).hasSize(numOfGroups);
+        assertThat(groups.stream().map(ForwardGroup::size).collect(Collectors.toList()))
+                .containsExactlyInAnyOrder(sizes);
     }
 
     private static DefaultExecutionGraph createDynamicGraph(JobVertex... vertices)