You are viewing a plain-text version of this content. The canonical version is available at its original link (the hyperlink was not preserved in this plain-text copy).
Posted to commits@flink.apache.org by zh...@apache.org on 2023/01/30 08:46:13 UTC

[flink] 03/03: [FLINK-15325][coordination] Ignores the input locations of a ConsumedPartitionGroup if the corresponding ConsumerVertexGroup is too large

This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 76ebeb93257369c136a9eeabd5bacb71a9699968
Author: Zhu Zhu <re...@gmail.com>
AuthorDate: Thu Jan 19 17:13:50 2023 +0800

    [FLINK-15325][coordination] Ignores the input locations of a ConsumedPartitionGroup if the corresponding ConsumerVertexGroup is too large
    
    This closes #21743.
---
 .../runtime/executiongraph/ExecutionGraph.java     |  10 ++
 .../runtime/executiongraph/ExecutionVertex.java    |   2 -
 .../AvailableInputsLocationsRetriever.java         |  12 +-
 .../DefaultPreferredLocationsRetriever.java        |  27 +++-
 .../flink/runtime/scheduler/DefaultScheduler.java  |  12 +-
 ...tionGraphToInputsLocationsRetrieverAdapter.java |  38 +++---
 .../scheduler/InputsLocationsRetriever.java        |  26 ++--
 .../AvailableInputsLocationsRetrieverTest.java     |  20 ++-
 .../DefaultPreferredLocationsRetrieverTest.java    | 144 ++++++++++++++-------
 ...DefaultSyncPreferredLocationsRetrieverTest.java |  31 +++--
 ...GraphToInputsLocationsRetrieverAdapterTest.java |  72 ++++++++---
 .../scheduler/TestingInputsLocationsRetriever.java |  84 +++++++++---
 .../adaptive/StateTrackingMockExecutionGraph.java  |   7 +
 13 files changed, 339 insertions(+), 146 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
index 2f90291e75c..cc1f5fa33c6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionGraph.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.MasterTriggerRestoreHook;
 import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
 import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
@@ -124,6 +125,15 @@ public interface ExecutionGraph extends AccessExecutionGraph {
 
     Map<IntermediateDataSetID, IntermediateResult> getAllIntermediateResults();
 
+    /**
+     * Gets the intermediate result partition by the given partition ID, or throw an exception if
+     * the partition is not found.
+     *
+     * @param id of the intermediate result partition
+     * @return intermediate result partition
+     */
+    IntermediateResultPartition getResultPartitionOrThrow(final IntermediateResultPartitionID id);
+
     /**
      * Merges all accumulator results from the tasks previously executed in the Executions.
      *
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
index 5ff589669b4..28f1d745e37 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/ExecutionVertex.java
@@ -62,8 +62,6 @@ public class ExecutionVertex
 
     public static final long NUM_BYTES_UNKNOWN = -1;
 
-    public static final int MAX_DISTINCT_LOCATIONS_TO_CONSIDER = 8;
-
     // --------------------------------------------------------------------------------------------
 
     final ExecutionJobVertex jobVertex;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java
index 4008b1cef01..2c2a0f44741 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetriever.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler;
 
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 
@@ -34,9 +35,16 @@ class AvailableInputsLocationsRetriever implements InputsLocationsRetriever {
     }
 
     @Override
-    public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers(
+    public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups(
             ExecutionVertexID executionVertexId) {
-        return inputsLocationsRetriever.getConsumedResultPartitionsProducers(executionVertexId);
+        return inputsLocationsRetriever.getConsumedPartitionGroups(executionVertexId);
+    }
+
+    @Override
+    public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup(
+            ConsumedPartitionGroup consumedPartitionGroup) {
+        return inputsLocationsRetriever.getProducersOfConsumedPartitionGroup(
+                consumedPartitionGroup);
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java
index 0f7fdeb1316..f3b71366682 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetriever.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.scheduler;
 
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.util.concurrent.FutureUtils;
@@ -30,7 +31,6 @@ import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 
-import static org.apache.flink.runtime.executiongraph.ExecutionVertex.MAX_DISTINCT_LOCATIONS_TO_CONSIDER;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /**
@@ -39,6 +39,10 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  */
 public class DefaultPreferredLocationsRetriever implements PreferredLocationsRetriever {
 
+    static final int MAX_DISTINCT_LOCATIONS_TO_CONSIDER = 8;
+
+    static final int MAX_DISTINCT_CONSUMERS_TO_CONSIDER = 8;
+
     private final StateLocationRetriever stateLocationRetriever;
 
     private final InputsLocationsRetriever inputsLocationsRetriever;
@@ -84,11 +88,24 @@ public class DefaultPreferredLocationsRetriever implements PreferredLocationsRet
         CompletableFuture<Collection<TaskManagerLocation>> preferredLocations =
                 CompletableFuture.completedFuture(Collections.emptyList());
 
-        final Collection<Collection<ExecutionVertexID>> allProducers =
-                inputsLocationsRetriever.getConsumedResultPartitionsProducers(executionVertexId);
-        for (Collection<ExecutionVertexID> producers : allProducers) {
+        final Collection<ConsumedPartitionGroup> consumedPartitionGroups =
+                inputsLocationsRetriever.getConsumedPartitionGroups(executionVertexId);
+        for (ConsumedPartitionGroup consumedPartitionGroup : consumedPartitionGroups) {
+            // Ignore the location of a consumed partition group if it has too many distinct
+            // consumers compared to the consumed partition group size. This is to avoid tasks
+            // unevenly distributed on nodes when running batch jobs or running jobs in
+            // session/standalone mode.
+            if ((double) consumedPartitionGroup.getConsumerVertexGroup().size()
+                            / consumedPartitionGroup.size()
+                    > MAX_DISTINCT_CONSUMERS_TO_CONSIDER) {
+                continue;
+            }
+
             final Collection<CompletableFuture<TaskManagerLocation>> locationsFutures =
-                    getInputLocationFutures(producersToIgnore, producers);
+                    getInputLocationFutures(
+                            producersToIgnore,
+                            inputsLocationsRetriever.getProducersOfConsumedPartitionGroup(
+                                    consumedPartitionGroup));
 
             preferredLocations = combineLocations(preferredLocations, locationsFutures);
         }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
index 2136479a2c7..ecd2a2467cb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
@@ -44,6 +44,7 @@ import org.apache.flink.runtime.jobmanager.scheduler.CoLocationGroup;
 import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.metrics.groups.JobManagerJobMetricGroup;
 import org.apache.flink.runtime.scheduler.exceptionhistory.FailureHandlingResultSnapshot;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategy;
 import org.apache.flink.runtime.scheduler.strategy.SchedulingStrategyFactory;
@@ -511,9 +512,16 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
         }
 
         @Override
-        public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers(
+        public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups(
                 ExecutionVertexID executionVertexId) {
-            return inputsLocationsRetriever.getConsumedResultPartitionsProducers(executionVertexId);
+            return inputsLocationsRetriever.getConsumedPartitionGroups(executionVertexId);
+        }
+
+        @Override
+        public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup(
+                ConsumedPartitionGroup consumedPartitionGroup) {
+            return inputsLocationsRetriever.getProducersOfConsumedPartitionGroup(
+                    consumedPartitionGroup);
         }
 
         @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java
index 0a6786c8ac6..35f3da6c868 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapter.java
@@ -22,17 +22,15 @@ import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
-import org.apache.flink.runtime.executiongraph.InternalExecutionGraphAccessor;
-import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
+import org.apache.flink.util.IterableUtils;
 
-import java.util.ArrayList;
 import java.util.Collection;
-import java.util.List;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
@@ -47,26 +45,22 @@ public class ExecutionGraphToInputsLocationsRetrieverAdapter implements InputsLo
     }
 
     @Override
-    public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers(
+    public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups(
             ExecutionVertexID executionVertexId) {
-        ExecutionVertex ev = getExecutionVertex(executionVertexId);
-
-        InternalExecutionGraphAccessor executionGraphAccessor = ev.getExecutionGraphAccessor();
+        return getExecutionVertex(executionVertexId).getAllConsumedPartitionGroups();
+    }
 
-        List<Collection<ExecutionVertexID>> resultPartitionProducers =
-                new ArrayList<>(ev.getNumberOfInputs());
-        for (ConsumedPartitionGroup consumedPartitions : ev.getAllConsumedPartitionGroups()) {
-            List<ExecutionVertexID> producers = new ArrayList<>(consumedPartitions.size());
-            for (IntermediateResultPartitionID consumedPartitionId : consumedPartitions) {
-                ExecutionVertex producer =
-                        executionGraphAccessor
-                                .getResultPartitionOrThrow(consumedPartitionId)
-                                .getProducer();
-                producers.add(producer.getID());
-            }
-            resultPartitionProducers.add(producers);
-        }
-        return resultPartitionProducers;
+    @Override
+    public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup(
+            ConsumedPartitionGroup consumedPartitionGroup) {
+        return IterableUtils.toStream(consumedPartitionGroup)
+                .map(
+                        partition ->
+                                executionGraph
+                                        .getResultPartitionOrThrow(partition)
+                                        .getProducer()
+                                        .getID())
+                .collect(Collectors.toList());
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java
index ea143bb5e59..0c49f4bf5df 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/InputsLocationsRetriever.java
@@ -18,7 +18,8 @@
 
 package org.apache.flink.runtime.scheduler;
 
-import org.apache.flink.runtime.executiongraph.Execution;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 
@@ -26,22 +27,31 @@ import java.util.Collection;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 
-/** Component to retrieve the inputs locations of a {@link Execution}. */
+/** Component to retrieve the inputs locations of an {@link ExecutionVertex}. */
 public interface InputsLocationsRetriever {
 
     /**
-     * Get the producers of the result partitions consumed by an execution.
+     * Get the consumed result partition groups of an execution vertex.
      *
-     * @param executionVertexId identifies the execution
-     * @return the producers of the result partitions group by job vertex id
+     * @param executionVertexId identifies the execution vertex
+     * @return the consumed result partition groups
      */
-    Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers(
+    Collection<ConsumedPartitionGroup> getConsumedPartitionGroups(
             ExecutionVertexID executionVertexId);
 
     /**
-     * Get the task manager location future for an execution.
+     * Get the producer execution vertices of a consumed result partition group.
      *
-     * @param executionVertexId identifying the execution
+     * @param consumedPartitionGroup the consumed result partition group
+     * @return the ids of producer execution vertices
+     */
+    Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup(
+            ConsumedPartitionGroup consumedPartitionGroup);
+
+    /**
+     * Get the task manager location future for an execution vertex.
+     *
+     * @param executionVertexId identifying the execution vertex
      * @return the task manager location future
      */
     Optional<CompletableFuture<TaskManagerLocation>> getTaskManagerLocation(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java
index d2d5be5b18e..0ebcffdf570 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/AvailableInputsLocationsRetrieverTest.java
@@ -18,8 +18,11 @@
 
 package org.apache.flink.runtime.scheduler;
 
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 
+import org.apache.flink.shaded.guava30.com.google.common.collect.Iterables;
+
 import org.junit.jupiter.api.Test;
 
 import java.util.Collection;
@@ -68,15 +71,20 @@ class AvailableInputsLocationsRetrieverTest {
     }
 
     @Test
-    void testConsumedResultPartitionsProducers() {
+    void testGetConsumedPartitionGroupAndProducers() {
         TestingInputsLocationsRetriever originalLocationRetriever = getOriginalLocationRetriever();
         InputsLocationsRetriever availableInputsLocationsRetriever =
                 new AvailableInputsLocationsRetriever(originalLocationRetriever);
-        Collection<Collection<ExecutionVertexID>> producers =
-                availableInputsLocationsRetriever.getConsumedResultPartitionsProducers(EV2);
-        assertThat(producers).hasSize(1);
-        Collection<ExecutionVertexID> resultProducers = producers.iterator().next();
-        assertThat(resultProducers).containsExactly(EV1);
+
+        ConsumedPartitionGroup consumedPartitionGroup =
+                Iterables.getOnlyElement(
+                        (availableInputsLocationsRetriever.getConsumedPartitionGroups(EV2)));
+        assertThat(consumedPartitionGroup).hasSize(1);
+
+        Collection<ExecutionVertexID> producers =
+                availableInputsLocationsRetriever.getProducersOfConsumedPartitionGroup(
+                        consumedPartitionGroup);
+        assertThat(producers).containsExactly(EV1);
     }
 
     private static TestingInputsLocationsRetriever getOriginalLocationRetriever() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java
index 64944dcea00..6fda52dc173 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultPreferredLocationsRetrieverTest.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.runtime.scheduler;
 
-import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.LocalTaskManagerLocation;
@@ -27,12 +26,18 @@ import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 import java.util.concurrent.CompletableFuture;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 
+import static org.apache.flink.runtime.scheduler.DefaultPreferredLocationsRetriever.MAX_DISTINCT_CONSUMERS_TO_CONSIDER;
+import static org.apache.flink.runtime.scheduler.DefaultPreferredLocationsRetriever.MAX_DISTINCT_LOCATIONS_TO_CONSIDER;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests {@link DefaultPreferredLocationsRetriever}. */
@@ -65,36 +70,38 @@ class DefaultPreferredLocationsRetrieverTest {
     }
 
     @Test
-    void testInputLocationsIgnoresEdgeOfTooManyLocations() {
-        final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
-                new TestingInputsLocationsRetriever.Builder();
-
-        final ExecutionVertexID consumerId = new ExecutionVertexID(new JobVertexID(), 0);
-
-        final int producerParallelism = ExecutionVertex.MAX_DISTINCT_LOCATIONS_TO_CONSIDER + 1;
-        final List<ExecutionVertexID> producerIds = new ArrayList<>(producerParallelism);
-        final JobVertexID producerJobVertexId = new JobVertexID();
-        for (int i = 0; i < producerParallelism; i++) {
-            final ExecutionVertexID producerId = new ExecutionVertexID(producerJobVertexId, i);
-            locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId);
-            producerIds.add(producerId);
+    void testInputLocations() {
+        {
+            final List<TaskManagerLocation> producerLocations =
+                    Collections.singletonList(new LocalTaskManagerLocation());
+            testInputLocationsInternal(
+                    1,
+                    MAX_DISTINCT_CONSUMERS_TO_CONSIDER,
+                    producerLocations,
+                    producerLocations,
+                    Collections.emptySet());
         }
-
-        final TestingInputsLocationsRetriever inputsLocationsRetriever =
-                locationRetrieverBuilder.build();
-
-        for (int i = 0; i < producerParallelism; i++) {
-            inputsLocationsRetriever.markScheduled(producerIds.get(i));
+        {
+            final List<TaskManagerLocation> producerLocations =
+                    Arrays.asList(new LocalTaskManagerLocation(), new LocalTaskManagerLocation());
+            testInputLocationsInternal(
+                    2,
+                    MAX_DISTINCT_CONSUMERS_TO_CONSIDER * 2,
+                    producerLocations,
+                    producerLocations,
+                    Collections.emptySet());
         }
+    }
 
-        final PreferredLocationsRetriever locationsRetriever =
-                new DefaultPreferredLocationsRetriever(
-                        id -> Optional.empty(), inputsLocationsRetriever);
-
-        final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations =
-                locationsRetriever.getPreferredLocations(consumerId, Collections.emptySet());
+    @Test
+    void testInputLocationsIgnoresEdgeOfTooManyProducers() {
+        testNoPreferredInputLocationsInternal(MAX_DISTINCT_LOCATIONS_TO_CONSIDER + 1, 1);
+    }
 
-        assertThat(preferredLocations.getNow(null)).isEmpty();
+    @Test
+    void testInputLocationsIgnoresEdgeOfTooManyConsumers() {
+        testNoPreferredInputLocationsInternal(1, MAX_DISTINCT_CONSUMERS_TO_CONSIDER + 1);
+        testNoPreferredInputLocationsInternal(2, MAX_DISTINCT_CONSUMERS_TO_CONSIDER * 2 + 1);
     }
 
     @Test
@@ -110,8 +117,8 @@ class DefaultPreferredLocationsRetrieverTest {
         for (int i = 0; i < parallelism1; i++) {
             final ExecutionVertexID producerId = new ExecutionVertexID(jobVertexId1, i);
             producers1.add(producerId);
-            locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId);
         }
+        locationRetrieverBuilder.connectConsumerToProducers(consumerId, producers1);
 
         final JobVertexID jobVertexId2 = new JobVertexID();
         int parallelism2 = 5;
@@ -119,8 +126,8 @@ class DefaultPreferredLocationsRetrieverTest {
         for (int i = 0; i < parallelism2; i++) {
             final ExecutionVertexID producerId = new ExecutionVertexID(jobVertexId2, i);
             producers2.add(producerId);
-            locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId);
         }
+        locationRetrieverBuilder.connectConsumerToProducers(consumerId, producers2);
 
         final TestingInputsLocationsRetriever inputsLocationsRetriever =
                 locationRetrieverBuilder.build();
@@ -152,40 +159,83 @@ class DefaultPreferredLocationsRetrieverTest {
 
     @Test
     void testInputLocationsIgnoresExcludedProducers() {
-        final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
-                new TestingInputsLocationsRetriever.Builder();
+        final List<TaskManagerLocation> producerLocations =
+                Arrays.asList(new LocalTaskManagerLocation(), new LocalTaskManagerLocation());
+        final Set<Integer> producersToIgnore = Collections.singleton(0);
+        testInputLocationsInternal(
+                2, 1, producerLocations, producerLocations.subList(1, 2), producersToIgnore);
+    }
 
-        final ExecutionVertexID consumerId = new ExecutionVertexID(new JobVertexID(), 0);
+    private void testNoPreferredInputLocationsInternal(
+            final int producerParallelism, final int consumerParallelism) {
+        testInputLocationsInternal(
+                producerParallelism,
+                consumerParallelism,
+                Collections.emptyList(),
+                Collections.emptyList(),
+                Collections.emptySet());
+    }
+
+    private void testInputLocationsInternal(
+            final int producerParallelism,
+            final int consumerParallelism,
+            final List<TaskManagerLocation> producerLocations,
+            final List<TaskManagerLocation> expectedPreferredLocations,
+            final Set<Integer> indicesOfProducersToIgnore) {
 
         final JobVertexID producerJobVertexId = new JobVertexID();
+        final List<ExecutionVertexID> producerIds =
+                IntStream.range(0, producerParallelism)
+                        .mapToObj(i -> new ExecutionVertexID(producerJobVertexId, i))
+                        .collect(Collectors.toList());
 
-        final ExecutionVertexID producerId1 = new ExecutionVertexID(producerJobVertexId, 0);
-        locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId1);
+        final JobVertexID consumerJobVertexId = new JobVertexID();
+        final List<ExecutionVertexID> consumerIds =
+                IntStream.range(0, consumerParallelism)
+                        .mapToObj(i -> new ExecutionVertexID(consumerJobVertexId, i))
+                        .collect(Collectors.toList());
 
-        final ExecutionVertexID producerId2 = new ExecutionVertexID(producerJobVertexId, 1);
-        locationRetrieverBuilder.connectConsumerToProducer(consumerId, producerId2);
+        final TestingInputsLocationsRetriever.Builder locationRetrieverBuilder =
+                new TestingInputsLocationsRetriever.Builder();
+        locationRetrieverBuilder.connectConsumersToProducers(consumerIds, producerIds);
 
         final TestingInputsLocationsRetriever inputsLocationsRetriever =
                 locationRetrieverBuilder.build();
+        for (int i = 0; i < producerParallelism; i++) {
+            TaskManagerLocation producerLocation;
+            if (producerLocations.isEmpty()) {
+                // generate a random location if not specified
+                producerLocation = new LocalTaskManagerLocation();
+            } else {
+                producerLocation = producerLocations.get(i);
+            }
+            inputsLocationsRetriever.assignTaskManagerLocation(
+                    producerIds.get(i), producerLocation);
+        }
 
-        inputsLocationsRetriever.markScheduled(producerId1);
-        inputsLocationsRetriever.markScheduled(producerId2);
+        checkInputLocations(
+                consumerIds.get(0),
+                inputsLocationsRetriever,
+                expectedPreferredLocations,
+                indicesOfProducersToIgnore.stream()
+                        .map(index -> new ExecutionVertexID(producerJobVertexId, index))
+                        .collect(Collectors.toSet()));
+    }
 
-        inputsLocationsRetriever.assignTaskManagerLocation(producerId1);
-        inputsLocationsRetriever.assignTaskManagerLocation(producerId2);
+    private void checkInputLocations(
+            final ExecutionVertexID consumerId,
+            final TestingInputsLocationsRetriever inputsLocationsRetriever,
+            final List<TaskManagerLocation> expectedPreferredLocations,
+            final Set<ExecutionVertexID> producersToIgnore) {
 
         final PreferredLocationsRetriever locationsRetriever =
                 new DefaultPreferredLocationsRetriever(
                         id -> Optional.empty(), inputsLocationsRetriever);
 
         final CompletableFuture<Collection<TaskManagerLocation>> preferredLocations =
-                locationsRetriever.getPreferredLocations(
-                        consumerId, Collections.singleton(producerId1));
+                locationsRetriever.getPreferredLocations(consumerId, producersToIgnore);
 
-        assertThat(preferredLocations.getNow(null)).hasSize(1);
-
-        final TaskManagerLocation producerLocation2 =
-                inputsLocationsRetriever.getTaskManagerLocation(producerId2).get().getNow(null);
-        assertThat(preferredLocations.getNow(null)).containsExactly(producerLocation2);
+        assertThat(preferredLocations.getNow(null))
+                .containsExactlyInAnyOrderElementsOf(expectedPreferredLocations);
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java
index 43a61ae37fb..db7469d1001 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/DefaultSyncPreferredLocationsRetrieverTest.java
@@ -18,49 +18,52 @@
 
 package org.apache.flink.runtime.scheduler;
 
+import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 
 import org.junit.jupiter.api.Test;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Optional;
 
-import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createRandomExecutionVertexId;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Tests for {@link DefaultSyncPreferredLocationsRetriever}. */
 class DefaultSyncPreferredLocationsRetrieverTest {
-    private static final ExecutionVertexID EV1 = createRandomExecutionVertexId();
-    private static final ExecutionVertexID EV2 = createRandomExecutionVertexId();
-    private static final ExecutionVertexID EV3 = createRandomExecutionVertexId();
-    private static final ExecutionVertexID EV4 = createRandomExecutionVertexId();
-    private static final ExecutionVertexID EV5 = createRandomExecutionVertexId();
+    /** Upstream job vertex; its subtasks EV11..EV14 act as the producers. */
+    private static final JobVertexID JV1 = new JobVertexID();
+    private static final ExecutionVertexID EV11 = new ExecutionVertexID(JV1, 0);
+    private static final ExecutionVertexID EV12 = new ExecutionVertexID(JV1, 1);
+    private static final ExecutionVertexID EV13 = new ExecutionVertexID(JV1, 2);
+    private static final ExecutionVertexID EV14 = new ExecutionVertexID(JV1, 3);
+    /** Subtask of a different job vertex that consumes from EV11..EV14. */
+    private static final ExecutionVertexID EV21 = new ExecutionVertexID(new JobVertexID(), 0);
 
     @Test
     void testAvailableInputLocationRetrieval() {
         TestingInputsLocationsRetriever originalLocationRetriever =
                 new TestingInputsLocationsRetriever.Builder()
-                        .connectConsumerToProducer(EV5, EV1)
-                        .connectConsumerToProducer(EV5, EV2)
-                        .connectConsumerToProducer(EV5, EV3)
-                        .connectConsumerToProducer(EV5, EV4)
+                        .connectConsumerToProducers(EV21, Arrays.asList(EV11, EV12, EV13, EV14))
                         .build();
 
-        originalLocationRetriever.assignTaskManagerLocation(EV1);
-        originalLocationRetriever.markScheduled(EV2);
-        originalLocationRetriever.failTaskManagerLocation(EV3, new Throwable());
-        originalLocationRetriever.cancelTaskManagerLocation(EV4);
+        // Only EV11 gets a location assigned; EV12 is merely marked as scheduled, EV13's location
+        // is failed and EV14's is cancelled. Hence only EV11's location is expected below.
+        originalLocationRetriever.assignTaskManagerLocation(EV11);
+        originalLocationRetriever.markScheduled(EV12);
+        originalLocationRetriever.failTaskManagerLocation(EV13, new Throwable());
+        originalLocationRetriever.cancelTaskManagerLocation(EV14);
 
         SyncPreferredLocationsRetriever locationsRetriever =
                 new DefaultSyncPreferredLocationsRetriever(
                         executionVertexId -> Optional.empty(), originalLocationRetriever);
 
         Collection<TaskManagerLocation> preferredLocations =
-                locationsRetriever.getPreferredLocations(EV5, Collections.emptySet());
+                locationsRetriever.getPreferredLocations(EV21, Collections.emptySet());
         TaskManagerLocation expectedLocation =
-                originalLocationRetriever.getTaskManagerLocation(EV1).get().join();
+                originalLocationRetriever.getTaskManagerLocation(EV11).get().join();
 
         assertThat(preferredLocations).containsExactly(expectedLocation);
     }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java
index f01fb65248c..066cf9a0035 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/ExecutionGraphToInputsLocationsRetrieverAdapterTest.java
@@ -24,23 +24,27 @@ import org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmaster.TestingLogicalSlot;
 import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
+import org.apache.flink.util.IterableUtils;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.RegisterExtension;
 
 import java.util.Collection;
-import java.util.Collections;
 import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -52,16 +56,24 @@ class ExecutionGraphToInputsLocationsRetrieverAdapterTest {
     static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_EXTENSION =
             TestingUtils.defaultExecutorExtension();
 
-    /** Tests that can get the producers of consumed result partitions. */
+    /** Tests retrieving the consumed partition groups of a vertex and their producers. */
     @Test
-    void testGetConsumedResultPartitionsProducers() throws Exception {
+    void testGetConsumedPartitionGroupsAndProducers() throws Exception {
         final JobVertex producer1 = ExecutionGraphTestUtils.createNoOpVertex(1);
         final JobVertex producer2 = ExecutionGraphTestUtils.createNoOpVertex(1);
         final JobVertex consumer = ExecutionGraphTestUtils.createNoOpVertex(1);
-        consumer.connectNewDataSetAsInput(
-                producer1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
-        consumer.connectNewDataSetAsInput(
-                producer2, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        final IntermediateDataSet dataSet1 =
+                consumer.connectNewDataSetAsInput(
+                                producer1,
+                                DistributionPattern.ALL_TO_ALL,
+                                ResultPartitionType.PIPELINED)
+                        .getSource();
+        final IntermediateDataSet dataSet2 =
+                consumer.connectNewDataSetAsInput(
+                                producer2,
+                                DistributionPattern.ALL_TO_ALL,
+                                ResultPartitionType.PIPELINED)
+                        .getSource();
 
         final ExecutionGraph eg =
                 ExecutionGraphTestUtils.createExecutionGraph(
@@ -73,20 +85,36 @@ class ExecutionGraphToInputsLocationsRetrieverAdapterTest {
         ExecutionVertexID evIdOfProducer2 = new ExecutionVertexID(producer2.getID(), 0);
         ExecutionVertexID evIdOfConsumer = new ExecutionVertexID(consumer.getID(), 0);
 
-        Collection<Collection<ExecutionVertexID>> producersOfProducer1 =
-                inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfProducer1);
-        Collection<Collection<ExecutionVertexID>> producersOfProducer2 =
-                inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfProducer2);
-        Collection<Collection<ExecutionVertexID>> producersOfConsumer =
-                inputsLocationsRetriever.getConsumedResultPartitionsProducers(evIdOfConsumer);
-
-        assertThat(producersOfProducer1).isEmpty();
-        assertThat(producersOfProducer2).isEmpty();
-        assertThat(producersOfConsumer).hasSize(2);
-        assertThat(producersOfConsumer)
-                .containsExactlyInAnyOrder(
-                        Collections.singletonList(evIdOfProducer1),
-                        Collections.singletonList(evIdOfProducer2));
+        Collection<ConsumedPartitionGroup> consumedPartitionGroupsOfProducer1 =
+                inputsLocationsRetriever.getConsumedPartitionGroups(evIdOfProducer1);
+        Collection<ConsumedPartitionGroup> consumedPartitionGroupsOfProducer2 =
+                inputsLocationsRetriever.getConsumedPartitionGroups(evIdOfProducer2);
+        Collection<ConsumedPartitionGroup> consumedPartitionGroupsOfConsumer =
+                inputsLocationsRetriever.getConsumedPartitionGroups(evIdOfConsumer);
+
+        // The producers have no inputs, so they consume nothing.
+        assertThat(consumedPartitionGroupsOfProducer1).isEmpty();
+        assertThat(consumedPartitionGroupsOfProducer2).isEmpty();
+
+        // The consumer reads one partition group per connected data set.
+        final IntermediateResultPartitionID partitionId1 =
+                new IntermediateResultPartitionID(dataSet1.getId(), 0);
+        final IntermediateResultPartitionID partitionId2 =
+                new IntermediateResultPartitionID(dataSet2.getId(), 0);
+        assertThat(consumedPartitionGroupsOfConsumer).hasSize(2);
+        assertThat(
+                        consumedPartitionGroupsOfConsumer.stream()
+                                .flatMap(IterableUtils::toStream)
+                                .collect(Collectors.toSet()))
+                .containsExactlyInAnyOrder(partitionId1, partitionId2);
+
+        // Each group must resolve to the producer of the data set it was created from.
+        for (ConsumedPartitionGroup group : consumedPartitionGroupsOfConsumer) {
+            final ExecutionVertexID expectedProducer =
+                    group.getFirst().equals(partitionId1) ? evIdOfProducer1 : evIdOfProducer2;
+            assertThat(inputsLocationsRetriever.getProducersOfConsumedPartitionGroup(group))
+                    .containsExactly(expectedProducer);
+        }
     }
 
     /** Tests that it will get empty task manager location if vertex is not scheduled. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java
index 977a53a2a05..139bbc8b856 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/TestingInputsLocationsRetriever.java
@@ -18,10 +18,15 @@
 
 package org.apache.flink.runtime.scheduler;
 
-import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.scheduler.strategy.TestingSchedulingTopology;
 import org.apache.flink.runtime.taskmanager.LocalTaskManagerLocation;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
+import org.apache.flink.util.IterableUtils;
 
 import java.util.ArrayList;
 import java.util.Collection;
@@ -33,28 +38,48 @@ import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.runtime.scheduler.strategy.TestingSchedulingTopology.connectConsumersToProducersById;
+
 /** A simple inputs locations retriever for testing purposes. */
 class TestingInputsLocationsRetriever implements InputsLocationsRetriever {
 
-    private final Map<ExecutionVertexID, List<ExecutionVertexID>> producersByConsumer;
+    /** Consumed partition groups of each consumer vertex. */
+    private final Map<ExecutionVertexID, Collection<ConsumedPartitionGroup>>
+            vertexToConsumedPartitionGroups;
+
+    /** Producer vertex of each result partition. */
+    private final Map<IntermediateResultPartitionID, ExecutionVertexID> partitionToProducer;
 
     private final Map<ExecutionVertexID, CompletableFuture<TaskManagerLocation>>
             taskManagerLocationsByVertex = new HashMap<>();
 
     TestingInputsLocationsRetriever(
-            final Map<ExecutionVertexID, List<ExecutionVertexID>> producersByConsumer) {
-        this.producersByConsumer = new HashMap<>(producersByConsumer);
+            final Map<ExecutionVertexID, Collection<ConsumedPartitionGroup>>
+                    vertexToConsumedPartitionGroups,
+            final Map<IntermediateResultPartitionID, ExecutionVertexID> partitionToProducer) {
+
+        // Shallow copies: keys added to the given maps later, e.g. by further Builder calls
+        // after build(), are not visible to this retriever.
+        this.vertexToConsumedPartitionGroups = new HashMap<>(vertexToConsumedPartitionGroups);
+        this.partitionToProducer = new HashMap<>(partitionToProducer);
     }
 
     @Override
-    public Collection<Collection<ExecutionVertexID>> getConsumedResultPartitionsProducers(
+    public Collection<ConsumedPartitionGroup> getConsumedPartitionGroups(
             final ExecutionVertexID executionVertexId) {
-        final Map<JobVertexID, List<ExecutionVertexID>> executionVerticesByJobVertex =
-                producersByConsumer.getOrDefault(executionVertexId, Collections.emptyList())
-                        .stream()
-                        .collect(Collectors.groupingBy(ExecutionVertexID::getJobVertexId));
+        // Vertices never connected as consumers (e.g. sources) consume nothing: return an empty
+        // collection rather than null so that callers can iterate the result directly.
+        return vertexToConsumedPartitionGroups.getOrDefault(
+                executionVertexId, Collections.emptyList());
+    }
 
-        return new ArrayList<>(executionVerticesByJobVertex.values());
+    @Override
+    public Collection<ExecutionVertexID> getProducersOfConsumedPartitionGroup(
+            ConsumedPartitionGroup consumedPartitionGroup) {
+        // Partitions without a registered producer yield null entries.
+        return IterableUtils.toStream(consumedPartitionGroup)
+                .map(partitionToProducer::get)
+                .collect(Collectors.toList());
     }
 
     @Override
@@ -68,13 +85,24 @@ class TestingInputsLocationsRetriever implements InputsLocationsRetriever {
     }
 
+    /** Assigns a new {@link LocalTaskManagerLocation} to the given vertex. */
     public void assignTaskManagerLocation(final ExecutionVertexID executionVertexId) {
+        assignTaskManagerLocation(executionVertexId, new LocalTaskManagerLocation());
+    }
+
+    /**
+     * Assigns {@code location} to the given vertex. An existing location future of the vertex is
+     * completed with it (which has no effect if that future is already done); otherwise an
+     * already completed future holding {@code location} is stored.
+     */
+    public void assignTaskManagerLocation(
+            final ExecutionVertexID executionVertexId, TaskManagerLocation location) {
         taskManagerLocationsByVertex.compute(
                 executionVertexId,
                 (key, future) -> {
                     if (future == null) {
-                        return CompletableFuture.completedFuture(new LocalTaskManagerLocation());
+                        return CompletableFuture.completedFuture(location);
                     }
-                    future.complete(new LocalTaskManagerLocation());
+                    future.complete(location);
                     return future;
                 });
     }
@@ -107,17 +129,59 @@ class TestingInputsLocationsRetriever implements InputsLocationsRetriever {
 
     static class Builder {
 
-        private final Map<ExecutionVertexID, List<ExecutionVertexID>> producersByConsumer =
+        /** Consumed partition groups collected so far, keyed by consumer vertex. */
+        private final Map<ExecutionVertexID, Collection<ConsumedPartitionGroup>>
+                vertexToConsumedPartitionGroups = new HashMap<>();
+
+        /** Producer of each result partition created so far. */
+        private final Map<IntermediateResultPartitionID, ExecutionVertexID> partitionToProducer =
                 new HashMap<>();
 
+        /** Connects {@code consumer} to the single {@code producer} through a new data set. */
         public Builder connectConsumerToProducer(
                 final ExecutionVertexID consumer, final ExecutionVertexID producer) {
-            producersByConsumer.computeIfAbsent(consumer, (key) -> new ArrayList<>()).add(producer);
+            return connectConsumerToProducers(consumer, Collections.singletonList(producer));
+        }
+
+        /** Connects {@code consumer} to all {@code producers} through one new data set. */
+        public Builder connectConsumerToProducers(
+                final ExecutionVertexID consumer, final List<ExecutionVertexID> producers) {
+            return connectConsumersToProducers(Collections.singletonList(consumer), producers);
+        }
+
+        /**
+         * Connects {@code consumers} to {@code producers} through one new pipelined data set.
+         * Every consumer is assigned the consumed partition group of that connection.
+         */
+        public Builder connectConsumersToProducers(
+                final List<ExecutionVertexID> consumers, final List<ExecutionVertexID> producers) {
+            final TestingSchedulingTopology.ConnectionResult connectionResult =
+                    connectConsumersToProducersById(
+                            consumers,
+                            producers,
+                            new IntermediateDataSetID(),
+                            ResultPartitionType.PIPELINED);
+
+            // The i-th result partition of the connection is attributed to the i-th producer.
+            int producerIndex = 0;
+            for (ExecutionVertexID producer : producers) {
+                partitionToProducer.put(
+                        connectionResult.getResultPartitions().get(producerIndex), producer);
+                producerIndex++;
+            }
+
+            for (ExecutionVertexID consumer : consumers) {
+                vertexToConsumedPartitionGroups
+                        .computeIfAbsent(consumer, ignore -> new ArrayList<>())
+                        .add(connectionResult.getConsumedPartitionGroup());
+            }
+
             return this;
         }
 
         public TestingInputsLocationsRetriever build() {
-            return new TestingInputsLocationsRetriever(producersByConsumer);
+            return new TestingInputsLocationsRetriever(
+                    vertexToConsumedPartitionGroups, partitionToProducer);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java
index 639c846cbec..c87e3473ab3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptive/StateTrackingMockExecutionGraph.java
@@ -43,11 +43,13 @@ import org.apache.flink.runtime.executiongraph.ExecutionGraph;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.executiongraph.JobStatusListener;
 import org.apache.flink.runtime.executiongraph.JobVertexInputInfo;
 import org.apache.flink.runtime.executiongraph.TaskExecutionStateTransition;
 import org.apache.flink.runtime.executiongraph.failover.flip1.ResultPartitionAvailabilityChecker;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
@@ -314,6 +316,12 @@ class StateTrackingMockExecutionGraph implements ExecutionGraph {
         throw new UnsupportedOperationException();
     }
 
+    @Override
+    public IntermediateResultPartition getResultPartitionOrThrow(IntermediateResultPartitionID id) {
+        // Like the neighbouring stubs, this mock does not support result partition lookups.
+        throw new UnsupportedOperationException();
+    }
+
     @Override
     public Map<String, OptionalFailure<Accumulator<?, ?>>> aggregateUserAccumulators() {
         throw new UnsupportedOperationException();