Posted to commits@flink.apache.org by zh...@apache.org on 2023/01/30 08:46:12 UTC

[flink] 02/03: [FLINK-15325][coordination] Set the ConsumedPartitionGroup/ConsumerVertexGroup to its corresponding ConsumerVertexGroup/ConsumedPartitionGroup

This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit c0bfb0b04bb38411c390813fea1ff93f6638f409
Author: Zhu Zhu <re...@gmail.com>
AuthorDate: Mon Jan 30 12:02:13 2023 +0800

    [FLINK-15325][coordination] Set the ConsumedPartitionGroup/ConsumerVertexGroup to its corresponding ConsumerVertexGroup/ConsumedPartitionGroup
---
 .../executiongraph/EdgeManagerBuildUtil.java       |   3 +
 .../scheduler/strategy/ConsumedPartitionGroup.java |  20 +-
 .../scheduler/strategy/ConsumerVertexGroup.java    |  21 +-
 .../executiongraph/EdgeManagerBuildUtilTest.java   |  64 ++++--
 .../strategy/TestingSchedulingExecutionVertex.java |  61 +-----
 .../strategy/TestingSchedulingResultPartition.java |  14 +-
 .../strategy/TestingSchedulingTopology.java        | 236 ++++++++++++++-------
 7 files changed, 244 insertions(+), 175 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
index a3613eba27c..889d7a39ae1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
@@ -174,6 +174,9 @@ public class EdgeManagerBuildUtil {
         for (IntermediateResultPartition partition : partitions) {
             partition.addConsumers(consumerVertexGroup);
         }
+
+        consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup);
+        consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup);
     }
 
     private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
index c3d32fa99fb..e4440d7575c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
@@ -23,14 +23,21 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.util.Preconditions;
 
+import javax.annotation.Nullable;
+
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
-/** Group of consumed {@link IntermediateResultPartitionID}s. */
+/**
+ * Group of consumed {@link IntermediateResultPartitionID}s. Each such group corresponds to one
+ * {@link ConsumerVertexGroup}.
+ */
 public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartitionID> {
 
     private final List<IntermediateResultPartitionID> resultPartitions;
@@ -44,6 +51,8 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
     /** Number of consumer tasks in the corresponding {@link ConsumerVertexGroup}. */
     private final int numConsumers;
 
+    @Nullable private ConsumerVertexGroup consumerVertexGroup;
+
     private ConsumedPartitionGroup(
             int numConsumers,
             List<IntermediateResultPartitionID> resultPartitions,
@@ -130,4 +139,13 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
     public ResultPartitionType getResultPartitionType() {
         return resultPartitionType;
     }
+
+    public ConsumerVertexGroup getConsumerVertexGroup() {
+        return checkNotNull(consumerVertexGroup, "ConsumerVertexGroup is not properly set.");
+    }
+
+    public void setConsumerVertexGroup(ConsumerVertexGroup consumerVertexGroup) {
+        checkState(this.consumerVertexGroup == null);
+        this.consumerVertexGroup = checkNotNull(consumerVertexGroup);
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java
index fb8b3f1951c..9939206b7d9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java
@@ -20,16 +20,26 @@ package org.apache.flink.runtime.scheduler.strategy;
 
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 
+import javax.annotation.Nullable;
+
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 
-/** Group of consumer {@link ExecutionVertexID}s. */
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Group of consumer {@link ExecutionVertexID}s. Each such group corresponds to one {@link
+ * ConsumedPartitionGroup}.
+ */
 public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> {
     private final List<ExecutionVertexID> vertices;
 
     private final ResultPartitionType resultPartitionType;
 
+    @Nullable private ConsumedPartitionGroup consumedPartitionGroup;
+
     private ConsumerVertexGroup(
             List<ExecutionVertexID> vertices, ResultPartitionType resultPartitionType) {
         this.vertices = vertices;
@@ -66,4 +76,13 @@ public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> {
     public ExecutionVertexID getFirst() {
         return iterator().next();
     }
+
+    public ConsumedPartitionGroup getConsumedPartitionGroup() {
+        return checkNotNull(consumedPartitionGroup, "ConsumedPartitionGroup is not properly set.");
+    }
+
+    public void setConsumedPartitionGroup(ConsumedPartitionGroup consumedPartitionGroup) {
+        checkState(this.consumedPartitionGroup == null);
+        this.consumedPartitionGroup = checkNotNull(consumedPartitionGroup);
+    }
 }
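
For context, the new setters above establish a one-to-one link between a ConsumedPartitionGroup and its ConsumerVertexGroup, and the new getters let scheduler code navigate that link from either side (the getters fail fast if the link was never set, and the setters refuse to overwrite an existing link). The following minimal sketch is not part of the commit; it only illustrates the behavior using factory methods and constructors that appear in this diff. The sketch class name, the chosen IDs, and the BLOCKING partition type are illustrative assumptions.

    import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
    import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
    import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
    import org.apache.flink.runtime.jobgraph.JobVertexID;
    import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
    import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
    import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;

    import java.util.Arrays;
    import java.util.List;

    public class GroupLinkSketch {
        public static void main(String[] args) {
            // A single consumed partition, produced by subtask 0 of some upstream vertex.
            IntermediateResultPartitionID partitionId =
                    new IntermediateResultPartitionID(new IntermediateDataSetID(), 0);

            // Two consumer subtasks of the same downstream job vertex.
            JobVertexID consumerJobVertexId = new JobVertexID();
            List<ExecutionVertexID> consumers =
                    Arrays.asList(
                            new ExecutionVertexID(consumerJobVertexId, 0),
                            new ExecutionVertexID(consumerJobVertexId, 1));

            ConsumedPartitionGroup consumedPartitionGroup =
                    ConsumedPartitionGroup.fromSinglePartition(
                            consumers.size(), partitionId, ResultPartitionType.BLOCKING);
            ConsumerVertexGroup consumerVertexGroup =
                    ConsumerVertexGroup.fromMultipleVertices(
                            consumers, ResultPartitionType.BLOCKING);

            // Wire the two groups to each other, as EdgeManagerBuildUtil now does.
            consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup);
            consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup);

            // The link can then be navigated from either side.
            System.out.println(
                    consumedPartitionGroup.getConsumerVertexGroup() == consumerVertexGroup); // true
            System.out.println(
                    consumerVertexGroup.getConsumedPartitionGroup() == consumedPartitionGroup); // true

            // Calling a setter a second time would fail its checkState precondition,
            // since each group may be linked to exactly one counterpart.
        }
    }
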
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index 41422bb48b1..63ecd9d442b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
 import org.apache.flink.testutils.TestingUtils;
 import org.apache.flink.testutils.executor.TestExecutorExtension;
@@ -116,21 +117,24 @@ class EdgeManagerBuildUtilTest {
         ExecutionVertex vertex2 = consumer.getTaskVertices()[1];
 
         // check consumers of the partitions
-        assertThat(partition1.getConsumerVertexGroups().get(0))
-                .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
-        assertThat(partition1.getConsumerVertexGroups().get(0))
-                .isEqualTo(partition1.getConsumerVertexGroups().get(0));
-        assertThat(partition3.getConsumerVertexGroups().get(0))
-                .isEqualTo(partition1.getConsumerVertexGroups().get(0));
+        ConsumerVertexGroup consumerVertexGroup = partition1.getConsumerVertexGroups().get(0);
+        assertThat(consumerVertexGroup).containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
+        assertThat(partition2.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup);
+        assertThat(partition3.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup);
 
         // check inputs of the execution vertices
-        assertThat(vertex1.getConsumedPartitionGroup(0))
+        ConsumedPartitionGroup consumedPartitionGroup = vertex1.getConsumedPartitionGroup(0);
+        assertThat(consumedPartitionGroup)
                 .containsExactlyInAnyOrder(
                         partition1.getPartitionId(),
                         partition2.getPartitionId(),
                         partition3.getPartitionId());
-        assertThat(vertex2.getConsumedPartitionGroup(0))
-                .isEqualTo(vertex1.getConsumedPartitionGroup(0));
+        assertThat(vertex2.getConsumedPartitionGroup(0)).isEqualTo(consumedPartitionGroup);
+
+        // check that the consumerVertexGroup and consumedPartitionGroup are set to each other
+        assertThat(consumerVertexGroup.getConsumedPartitionGroup())
+                .isEqualTo(consumedPartitionGroup);
+        assertThat(consumedPartitionGroup.getConsumerVertexGroup()).isEqualTo(consumerVertexGroup);
     }
 
     @Test
@@ -186,25 +190,39 @@ class EdgeManagerBuildUtilTest {
         ExecutionVertex vertex4 = consumer.getTaskVertices()[3];
 
         // check consumers of the partitions
-        assertThat(partition1.getConsumerVertexGroups().get(0))
+        ConsumerVertexGroup consumerVertexGroup1 = partition1.getConsumerVertexGroups().get(0);
+        ConsumerVertexGroup consumerVertexGroup2 = partition2.getConsumerVertexGroups().get(0);
+        ConsumerVertexGroup consumerVertexGroup3 = partition4.getConsumerVertexGroups().get(0);
+        assertThat(consumerVertexGroup1)
                 .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
-        assertThat(partition2.getConsumerVertexGroups().get(0))
-                .containsExactlyInAnyOrder(vertex3.getID());
-        assertThat(partition3.getConsumerVertexGroups().get(0))
-                .isEqualTo(partition2.getConsumerVertexGroups().get(0));
-        assertThat(partition4.getConsumerVertexGroups().get(0))
-                .containsExactlyInAnyOrder(vertex4.getID());
+        assertThat(consumerVertexGroup2).containsExactlyInAnyOrder(vertex3.getID());
+        assertThat(partition3.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup2);
+        assertThat(consumerVertexGroup3).containsExactlyInAnyOrder(vertex4.getID());
 
         // check inputs of the execution vertices
-        assertThat(vertex1.getConsumedPartitionGroup(0))
-                .containsExactlyInAnyOrder(partition1.getPartitionId());
-        assertThat(vertex2.getConsumedPartitionGroup(0))
-                .isEqualTo(vertex1.getConsumedPartitionGroup(0));
-        assertThat(vertex3.getConsumedPartitionGroup(0))
+        ConsumedPartitionGroup consumedPartitionGroup1 = vertex1.getConsumedPartitionGroup(0);
+        ConsumedPartitionGroup consumedPartitionGroup2 = vertex3.getConsumedPartitionGroup(0);
+        ConsumedPartitionGroup consumedPartitionGroup3 = vertex4.getConsumedPartitionGroup(0);
+        assertThat(consumedPartitionGroup1).containsExactlyInAnyOrder(partition1.getPartitionId());
+        assertThat(vertex2.getConsumedPartitionGroup(0)).isEqualTo(consumedPartitionGroup1);
+        assertThat(consumedPartitionGroup2)
                 .containsExactlyInAnyOrder(
                         partition2.getPartitionId(), partition3.getPartitionId());
-        assertThat(vertex4.getConsumedPartitionGroup(0))
-                .containsExactlyInAnyOrder(partition4.getPartitionId());
+        assertThat(consumedPartitionGroup3).containsExactlyInAnyOrder(partition4.getPartitionId());
+
+        // check that the consumerVertexGroups and consumedPartitionGroups are properly set
+        assertThat(consumerVertexGroup1.getConsumedPartitionGroup())
+                .isEqualTo(consumedPartitionGroup1);
+        assertThat(consumedPartitionGroup1.getConsumerVertexGroup())
+                .isEqualTo(consumerVertexGroup1);
+        assertThat(consumerVertexGroup2.getConsumedPartitionGroup())
+                .isEqualTo(consumedPartitionGroup2);
+        assertThat(consumedPartitionGroup2.getConsumerVertexGroup())
+                .isEqualTo(consumerVertexGroup2);
+        assertThat(consumerVertexGroup3.getConsumedPartitionGroup())
+                .isEqualTo(consumedPartitionGroup3);
+        assertThat(consumedPartitionGroup3.getConsumerVertexGroup())
+                .isEqualTo(consumerVertexGroup3);
     }
 
     private void testGetMaxNumEdgesToTarget(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
index 3bc9c404fae..b549da3e976 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
@@ -19,7 +19,6 @@
 package org.apache.flink.runtime.scheduler.strategy;
 
 import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.util.IterableUtils;
@@ -30,8 +29,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
-import static org.apache.flink.util.Preconditions.checkNotNull;
-
 /** A simple scheduling execution vertex for testing purposes. */
 public class TestingSchedulingExecutionVertex implements SchedulingExecutionVertex {
 
@@ -47,17 +44,12 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
     private ExecutionState executionState;
 
     public TestingSchedulingExecutionVertex(
-            JobVertexID jobVertexId,
-            int subtaskIndex,
-            List<ConsumedPartitionGroup> consumedPartitionGroups,
-            Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
-                    resultPartitionsById,
-            ExecutionState executionState) {
+            JobVertexID jobVertexId, int subtaskIndex, ExecutionState executionState) {
 
         this.executionVertexId = new ExecutionVertexID(jobVertexId, subtaskIndex);
-        this.consumedPartitionGroups = checkNotNull(consumedPartitionGroups);
+        this.consumedPartitionGroups = new ArrayList<>();
         this.producedPartitions = new ArrayList<>();
-        this.resultPartitionsById = checkNotNull(resultPartitionsById);
+        this.resultPartitionsById = new HashMap<>();
         this.executionState = executionState;
     }
 
@@ -90,22 +82,6 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
         return consumedPartitionGroups;
     }
 
-    void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) {
-        final ConsumedPartitionGroup consumedPartitionGroup =
-                ConsumedPartitionGroup.fromSinglePartition(
-                        consumedPartition.getNumConsumers(),
-                        consumedPartition.getId(),
-                        consumedPartition.getResultType());
-
-        consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
-        if (consumedPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) {
-            consumedPartitionGroup.partitionFinished();
-        }
-
-        this.consumedPartitionGroups.add(consumedPartitionGroup);
-        this.resultPartitionsById.putIfAbsent(consumedPartition.getId(), consumedPartition);
-    }
-
     void addConsumedPartitionGroup(
             ConsumedPartitionGroup consumedPartitionGroup,
             Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
@@ -131,9 +107,6 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
     public static class Builder {
         private JobVertexID jobVertexId = new JobVertexID();
         private int subtaskIndex = 0;
-        private final List<ConsumedPartitionGroup> consumedPartitionGroups = new ArrayList<>();
-        private final Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
-                resultPartitionsById = new HashMap<>();
         private ExecutionState executionState = ExecutionState.CREATED;
 
         Builder withExecutionVertexID(JobVertexID jobVertexId, int subtaskIndex) {
@@ -142,39 +115,13 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
             return this;
         }
 
-        public Builder withConsumedPartitionGroups(
-                List<ConsumedPartitionGroup> consumedPartitionGroups,
-                Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
-                        resultPartitionsById) {
-            this.resultPartitionsById.putAll(resultPartitionsById);
-            final ResultPartitionType resultType =
-                    resultPartitionsById.values().iterator().next().getResultType();
-
-            for (ConsumedPartitionGroup partitionGroup : consumedPartitionGroups) {
-                List<IntermediateResultPartitionID> partitionIds =
-                        new ArrayList<>(partitionGroup.size());
-                for (IntermediateResultPartitionID partitionId : partitionGroup) {
-                    partitionIds.add(partitionId);
-                }
-                this.consumedPartitionGroups.add(
-                        ConsumedPartitionGroup.fromMultiplePartitions(
-                                partitionGroup.getNumConsumers(), partitionIds, resultType));
-            }
-            return this;
-        }
-
         public Builder withExecutionState(ExecutionState executionState) {
             this.executionState = executionState;
             return this;
         }
 
         public TestingSchedulingExecutionVertex build() {
-            return new TestingSchedulingExecutionVertex(
-                    jobVertexId,
-                    subtaskIndex,
-                    consumedPartitionGroups,
-                    resultPartitionsById,
-                    executionState);
+            return new TestingSchedulingExecutionVertex(jobVertexId, subtaskIndex, executionState);
         }
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
index 70759fd2b6c..77514ba57b2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
@@ -25,10 +25,8 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import javax.annotation.Nullable;
 
 import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
-import java.util.stream.Collectors;
 
 import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
@@ -102,18 +100,8 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
         return Collections.unmodifiableList(consumedPartitionGroups);
     }
 
-    void addConsumerGroup(
-            Collection<TestingSchedulingExecutionVertex> consumerVertices,
-            ResultPartitionType resultPartitionType) {
+    void addConsumerGroup(ConsumerVertexGroup consumerVertexGroup) {
         checkState(this.consumerVertexGroup == null);
-
-        final ConsumerVertexGroup consumerVertexGroup =
-                ConsumerVertexGroup.fromMultipleVertices(
-                        consumerVertices.stream()
-                                .map(TestingSchedulingExecutionVertex::getId)
-                                .collect(Collectors.toList()),
-                        resultPartitionType);
-
         this.consumerVertexGroup = consumerVertexGroup;
     }
 
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
index ec86fa36d9b..b9653db3cf0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
@@ -38,6 +38,7 @@ import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
 
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /** A simple scheduling topology for testing purposes. */
@@ -189,17 +190,12 @@ public class TestingSchedulingTopology implements SchedulingTopology {
             TestingSchedulingExecutionVertex consumer,
             ResultPartitionType resultPartitionType) {
 
-        final TestingSchedulingResultPartition resultPartition =
-                new TestingSchedulingResultPartition.Builder()
-                        .withResultPartitionType(resultPartitionType)
-                        .build();
-
-        resultPartition.addConsumerGroup(
-                Collections.singleton(consumer), resultPartition.getResultType());
-        resultPartition.setProducer(producer);
-
-        producer.addProducedPartition(resultPartition);
-        consumer.addConsumedPartition(resultPartition);
+        connectConsumersToProducers(
+                Collections.singletonList(consumer),
+                Collections.singletonList(producer),
+                new IntermediateDataSetID(),
+                resultPartitionType,
+                ResultPartitionState.ALL_DATA_PRODUCED);
 
         updateVertexResultPartitions(producer);
         updateVertexResultPartitions(consumer);
@@ -223,6 +219,142 @@ public class TestingSchedulingTopology implements SchedulingTopology {
         return new ProducerConsumerAllToAllConnectionBuilder(producers, consumers);
     }
 
+    private static List<TestingSchedulingResultPartition> connectConsumersToProducers(
+            final List<TestingSchedulingExecutionVertex> consumers,
+            final List<TestingSchedulingExecutionVertex> producers,
+            final IntermediateDataSetID intermediateDataSetId,
+            final ResultPartitionType resultPartitionType,
+            final ResultPartitionState resultPartitionState) {
+
+        final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>();
+
+        final ConnectionResult connectionResult =
+                connectConsumersToProducersById(
+                        consumers.stream()
+                                .map(SchedulingExecutionVertex::getId)
+                                .collect(Collectors.toList()),
+                        producers.stream()
+                                .map(SchedulingExecutionVertex::getId)
+                                .collect(Collectors.toList()),
+                        intermediateDataSetId,
+                        resultPartitionType);
+
+        final ConsumedPartitionGroup consumedPartitionGroup =
+                connectionResult.getConsumedPartitionGroup();
+        final ConsumerVertexGroup consumerVertexGroup = connectionResult.getConsumerVertexGroup();
+
+        final TestingSchedulingResultPartition.Builder resultPartitionBuilder =
+                new TestingSchedulingResultPartition.Builder()
+                        .withIntermediateDataSetID(intermediateDataSetId)
+                        .withResultPartitionType(resultPartitionType)
+                        .withResultPartitionState(resultPartitionState);
+
+        for (int i = 0; i < producers.size(); i++) {
+            final TestingSchedulingExecutionVertex producer = producers.get(i);
+            final IntermediateResultPartitionID partitionId =
+                    connectionResult.getResultPartitions().get(i);
+            final TestingSchedulingResultPartition resultPartition =
+                    resultPartitionBuilder
+                            .withPartitionNum(partitionId.getPartitionNumber())
+                            .build();
+
+            producer.addProducedPartition(resultPartition);
+
+            resultPartition.setProducer(producer);
+            resultPartitions.add(resultPartition);
+            resultPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
+            resultPartition.addConsumerGroup(consumerVertexGroup);
+
+            if (resultPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) {
+                consumedPartitionGroup.partitionFinished();
+            }
+        }
+
+        final Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
+                consumedPartitionById =
+                        resultPartitions.stream()
+                                .collect(
+                                        Collectors.toMap(
+                                                TestingSchedulingResultPartition::getId,
+                                                Function.identity()));
+        for (TestingSchedulingExecutionVertex consumer : consumers) {
+            consumer.addConsumedPartitionGroup(consumedPartitionGroup, consumedPartitionById);
+        }
+
+        return resultPartitions;
+    }
+
+    public static ConnectionResult connectConsumersToProducersById(
+            final List<ExecutionVertexID> consumers,
+            final List<ExecutionVertexID> producers,
+            final IntermediateDataSetID intermediateDataSetId,
+            final ResultPartitionType resultPartitionType) {
+
+        final List<IntermediateResultPartitionID> resultPartitions = new ArrayList<>();
+        for (ExecutionVertexID producer : producers) {
+            final IntermediateResultPartitionID resultPartition =
+                    new IntermediateResultPartitionID(
+                            intermediateDataSetId, producer.getSubtaskIndex());
+            resultPartitions.add(resultPartition);
+        }
+
+        final ConsumedPartitionGroup consumedPartitionGroup =
+                createConsumedPartitionGroup(
+                        consumers.size(), resultPartitions, resultPartitionType);
+        final ConsumerVertexGroup consumerVertexGroup =
+                createConsumerVertexGroup(consumers, resultPartitionType);
+
+        consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup);
+        consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup);
+
+        return new ConnectionResult(resultPartitions, consumedPartitionGroup, consumerVertexGroup);
+    }
+
+    private static ConsumedPartitionGroup createConsumedPartitionGroup(
+            final int numConsumers,
+            final List<IntermediateResultPartitionID> consumedPartitions,
+            final ResultPartitionType resultPartitionType) {
+        return ConsumedPartitionGroup.fromMultiplePartitions(
+                numConsumers, consumedPartitions, resultPartitionType);
+    }
+
+    private static ConsumerVertexGroup createConsumerVertexGroup(
+            final List<ExecutionVertexID> consumers,
+            final ResultPartitionType resultPartitionType) {
+        return ConsumerVertexGroup.fromMultipleVertices(consumers, resultPartitionType);
+    }
+
+    /**
+     * The result of connecting a set of consumers to their producers, including the created result
+     * partitions and the consumption groups.
+     */
+    public static class ConnectionResult {
+        private final List<IntermediateResultPartitionID> resultPartitions;
+        private final ConsumedPartitionGroup consumedPartitionGroup;
+        private final ConsumerVertexGroup consumerVertexGroup;
+
+        public ConnectionResult(
+                final List<IntermediateResultPartitionID> resultPartitions,
+                final ConsumedPartitionGroup consumedPartitionGroup,
+                final ConsumerVertexGroup consumerVertexGroup) {
+            this.resultPartitions = checkNotNull(resultPartitions);
+            this.consumedPartitionGroup = checkNotNull(consumedPartitionGroup);
+            this.consumerVertexGroup = checkNotNull(consumerVertexGroup);
+        }
+
+        public List<IntermediateResultPartitionID> getResultPartitions() {
+            return resultPartitions;
+        }
+
+        public ConsumedPartitionGroup getConsumedPartitionGroup() {
+            return consumedPartitionGroup;
+        }
+
+        public ConsumerVertexGroup getConsumerVertexGroup() {
+            return consumerVertexGroup;
+        }
+    }
+
     /** Builder for {@link TestingSchedulingResultPartition}. */
     public abstract class ProducerConsumerConnectionBuilder {
 
@@ -265,11 +397,6 @@ public class TestingSchedulingTopology implements SchedulingTopology {
             return resultPartitions;
         }
 
-        TestingSchedulingResultPartition.Builder initTestingSchedulingResultPartitionBuilder() {
-            return new TestingSchedulingResultPartition.Builder()
-                    .withResultPartitionType(resultPartitionType);
-        }
-
         protected abstract List<TestingSchedulingResultPartition> connect();
     }
 
@@ -292,25 +419,15 @@ public class TestingSchedulingTopology implements SchedulingTopology {
         protected List<TestingSchedulingResultPartition> connect() {
             final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>();
             final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID();
-
             for (int idx = 0; idx < producers.size(); idx++) {
-                final TestingSchedulingExecutionVertex producer = producers.get(idx);
-                final TestingSchedulingExecutionVertex consumer = consumers.get(idx);
-
-                final TestingSchedulingResultPartition resultPartition =
-                        initTestingSchedulingResultPartitionBuilder()
-                                .withIntermediateDataSetID(intermediateDataSetId)
-                                .withResultPartitionState(resultPartitionState)
-                                .withPartitionNum(idx)
-                                .build();
-                resultPartition.setProducer(producer);
-                producer.addProducedPartition(resultPartition);
-                consumer.addConsumedPartition(resultPartition);
-                resultPartition.addConsumerGroup(
-                        Collections.singleton(consumer), resultPartitionType);
-                resultPartitions.add(resultPartition);
+                resultPartitions.addAll(
+                        connectConsumersToProducers(
+                                Collections.singletonList(consumers.get(idx)),
+                                Collections.singletonList(producers.get(idx)),
+                                intermediateDataSetId,
+                                resultPartitionType,
+                                resultPartitionState));
             }
-
             return resultPartitions;
         }
     }
@@ -330,53 +447,12 @@ public class TestingSchedulingTopology implements SchedulingTopology {
 
         @Override
         protected List<TestingSchedulingResultPartition> connect() {
-            final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>();
-            final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID();
-
-            TestingSchedulingResultPartition.Builder resultPartitionBuilder =
-                    initTestingSchedulingResultPartitionBuilder()
-                            .withIntermediateDataSetID(intermediateDataSetId)
-                            .withResultPartitionState(resultPartitionState);
-
-            int partitionNum = 0;
-
-            for (TestingSchedulingExecutionVertex producer : producers) {
-
-                final TestingSchedulingResultPartition resultPartition =
-                        resultPartitionBuilder.withPartitionNum(partitionNum++).build();
-                resultPartition.setProducer(producer);
-                producer.addProducedPartition(resultPartition);
-
-                resultPartition.addConsumerGroup(consumers, resultPartitionType);
-                resultPartitions.add(resultPartition);
-            }
-
-            ConsumedPartitionGroup consumedPartitionGroup =
-                    ConsumedPartitionGroup.fromMultiplePartitions(
-                            consumers.size(),
-                            resultPartitions.stream()
-                                    .map(TestingSchedulingResultPartition::getId)
-                                    .collect(Collectors.toList()),
-                            resultPartitions.get(0).getResultType());
-            Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
-                    consumedPartitionById =
-                            resultPartitions.stream()
-                                    .collect(
-                                            Collectors.toMap(
-                                                    TestingSchedulingResultPartition::getId,
-                                                    Function.identity()));
-            for (TestingSchedulingExecutionVertex consumer : consumers) {
-                consumer.addConsumedPartitionGroup(consumedPartitionGroup, consumedPartitionById);
-            }
-
-            for (TestingSchedulingResultPartition resultPartition : resultPartitions) {
-                resultPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
-                if (resultPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) {
-                    consumedPartitionGroup.partitionFinished();
-                }
-            }
-
-            return resultPartitions;
+            return connectConsumersToProducers(
+                    consumers,
+                    producers,
+                    new IntermediateDataSetID(),
+                    resultPartitionType,
+                    resultPartitionState);
         }
     }
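
For reference, here is a small usage sketch, not part of the commit, of the new public test helper TestingSchedulingTopology.connectConsumersToProducersById and the ConnectionResult it returns. The returned groups are already linked to each other inside the helper, so callers do not need to invoke the setters themselves. The sketch class name, the vertex IDs, and the BLOCKING partition type are illustrative assumptions.

    import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
    import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
    import org.apache.flink.runtime.jobgraph.JobVertexID;
    import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
    import org.apache.flink.runtime.scheduler.strategy.TestingSchedulingTopology;

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;

    public class ConnectionResultSketch {
        public static void main(String[] args) {
            // One producer subtask and two consumer subtasks, identified only by their IDs.
            List<ExecutionVertexID> producers =
                    Collections.singletonList(new ExecutionVertexID(new JobVertexID(), 0));
            JobVertexID consumerJobVertexId = new JobVertexID();
            List<ExecutionVertexID> consumers =
                    Arrays.asList(
                            new ExecutionVertexID(consumerJobVertexId, 0),
                            new ExecutionVertexID(consumerJobVertexId, 1));

            // The helper creates one result partition per producer and returns both groups.
            TestingSchedulingTopology.ConnectionResult connectionResult =
                    TestingSchedulingTopology.connectConsumersToProducersById(
                            consumers,
                            producers,
                            new IntermediateDataSetID(),
                            ResultPartitionType.BLOCKING);

            // The two groups already reference each other.
            System.out.println(
                    connectionResult.getConsumedPartitionGroup().getConsumerVertexGroup()
                            == connectionResult.getConsumerVertexGroup()); // true
            System.out.println(
                    connectionResult.getConsumerVertexGroup().getConsumedPartitionGroup()
                            == connectionResult.getConsumedPartitionGroup()); // true

            // One partition ID per producer.
            System.out.println(connectionResult.getResultPartitions().size()); // 1
        }
    }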