You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2023/01/30 08:46:12 UTC
[flink] 02/03: [FLINK-15325][coordination] Set the ConsumedPartitionGroup/ConsumerVertexGroup to its corresponding ConsumerVertexGroup/ConsumedPartitionGroup
This is an automated email from the ASF dual-hosted git repository.
zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit c0bfb0b04bb38411c390813fea1ff93f6638f409
Author: Zhu Zhu <re...@gmail.com>
AuthorDate: Mon Jan 30 12:02:13 2023 +0800
[FLINK-15325][coordination] Set the ConsumedPartitionGroup/ConsumerVertexGroup to its corresponding ConsumerVertexGroup/ConsumedPartitionGroup
---
.../executiongraph/EdgeManagerBuildUtil.java | 3 +
.../scheduler/strategy/ConsumedPartitionGroup.java | 20 +-
.../scheduler/strategy/ConsumerVertexGroup.java | 21 +-
.../executiongraph/EdgeManagerBuildUtilTest.java | 64 ++++--
.../strategy/TestingSchedulingExecutionVertex.java | 61 +-----
.../strategy/TestingSchedulingResultPartition.java | 14 +-
.../strategy/TestingSchedulingTopology.java | 236 ++++++++++++++-------
7 files changed, 244 insertions(+), 175 deletions(-)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
index a3613eba27c..889d7a39ae1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
@@ -174,6 +174,9 @@ public class EdgeManagerBuildUtil {
for (IntermediateResultPartition partition : partitions) {
partition.addConsumers(consumerVertexGroup);
}
+
+ consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup);
+ consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup);
}
private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
index c3d32fa99fb..e4440d7575c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
@@ -23,14 +23,21 @@ import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.util.Preconditions;
+import javax.annotation.Nullable;
+
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
-/** Group of consumed {@link IntermediateResultPartitionID}s. */
+/**
+ * Group of consumed {@link IntermediateResultPartitionID}s. One such group corresponds to one
+ * {@link ConsumerVertexGroup}.
+ */
public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartitionID> {
private final List<IntermediateResultPartitionID> resultPartitions;
@@ -44,6 +51,8 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
/** Number of consumer tasks in the corresponding {@link ConsumerVertexGroup}. */
private final int numConsumers;
+ @Nullable private ConsumerVertexGroup consumerVertexGroup;
+
private ConsumedPartitionGroup(
int numConsumers,
List<IntermediateResultPartitionID> resultPartitions,
@@ -130,4 +139,13 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
public ResultPartitionType getResultPartitionType() {
return resultPartitionType;
}
+
+ public ConsumerVertexGroup getConsumerVertexGroup() {
+ return checkNotNull(consumerVertexGroup, "ConsumerVertexGroup is not properly set.");
+ }
+
+ public void setConsumerVertexGroup(ConsumerVertexGroup consumerVertexGroup) {
+ checkState(this.consumerVertexGroup == null);
+ this.consumerVertexGroup = checkNotNull(consumerVertexGroup);
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java
index fb8b3f1951c..9939206b7d9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumerVertexGroup.java
@@ -20,16 +20,26 @@ package org.apache.flink.runtime.scheduler.strategy;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import javax.annotation.Nullable;
+
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
-/** Group of consumer {@link ExecutionVertexID}s. */
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/**
+ * Group of consumer {@link ExecutionVertexID}s. One such group corresponds to one {@link
+ * ConsumedPartitionGroup}.
+ */
public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> {
private final List<ExecutionVertexID> vertices;
private final ResultPartitionType resultPartitionType;
+ @Nullable private ConsumedPartitionGroup consumedPartitionGroup;
+
private ConsumerVertexGroup(
List<ExecutionVertexID> vertices, ResultPartitionType resultPartitionType) {
this.vertices = vertices;
@@ -66,4 +76,13 @@ public class ConsumerVertexGroup implements Iterable<ExecutionVertexID> {
public ExecutionVertexID getFirst() {
return iterator().next();
}
+
+ public ConsumedPartitionGroup getConsumedPartitionGroup() {
+ return checkNotNull(consumedPartitionGroup, "ConsumedPartitionGroup is not properly set.");
+ }
+
+ public void setConsumedPartitionGroup(ConsumedPartitionGroup consumedPartitionGroup) {
+ checkState(this.consumedPartitionGroup == null);
+ this.consumedPartitionGroup = checkNotNull(consumedPartitionGroup);
+ }
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index 41422bb48b1..63ecd9d442b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
@@ -116,21 +117,24 @@ class EdgeManagerBuildUtilTest {
ExecutionVertex vertex2 = consumer.getTaskVertices()[1];
// check consumers of the partitions
- assertThat(partition1.getConsumerVertexGroups().get(0))
- .containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
- assertThat(partition1.getConsumerVertexGroups().get(0))
- .isEqualTo(partition1.getConsumerVertexGroups().get(0));
- assertThat(partition3.getConsumerVertexGroups().get(0))
- .isEqualTo(partition1.getConsumerVertexGroups().get(0));
+ ConsumerVertexGroup consumerVertexGroup = partition1.getConsumerVertexGroups().get(0);
+ assertThat(consumerVertexGroup).containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
+ assertThat(partition2.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup);
+ assertThat(partition3.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup);
// check inputs of the execution vertices
- assertThat(vertex1.getConsumedPartitionGroup(0))
+ ConsumedPartitionGroup consumedPartitionGroup = vertex1.getConsumedPartitionGroup(0);
+ assertThat(consumedPartitionGroup)
.containsExactlyInAnyOrder(
partition1.getPartitionId(),
partition2.getPartitionId(),
partition3.getPartitionId());
- assertThat(vertex2.getConsumedPartitionGroup(0))
- .isEqualTo(vertex1.getConsumedPartitionGroup(0));
+ assertThat(vertex2.getConsumedPartitionGroup(0)).isEqualTo(consumedPartitionGroup);
+
+ // check the consumerVertexGroup and consumedPartitionGroup are set to each other
+ assertThat(consumerVertexGroup.getConsumedPartitionGroup())
+ .isEqualTo(consumedPartitionGroup);
+ assertThat(consumedPartitionGroup.getConsumerVertexGroup()).isEqualTo(consumerVertexGroup);
}
@Test
@@ -186,25 +190,39 @@ class EdgeManagerBuildUtilTest {
ExecutionVertex vertex4 = consumer.getTaskVertices()[3];
// check consumers of the partitions
- assertThat(partition1.getConsumerVertexGroups().get(0))
+ ConsumerVertexGroup consumerVertexGroup1 = partition1.getConsumerVertexGroups().get(0);
+ ConsumerVertexGroup consumerVertexGroup2 = partition2.getConsumerVertexGroups().get(0);
+ ConsumerVertexGroup consumerVertexGroup3 = partition4.getConsumerVertexGroups().get(0);
+ assertThat(consumerVertexGroup1)
.containsExactlyInAnyOrder(vertex1.getID(), vertex2.getID());
- assertThat(partition2.getConsumerVertexGroups().get(0))
- .containsExactlyInAnyOrder(vertex3.getID());
- assertThat(partition3.getConsumerVertexGroups().get(0))
- .isEqualTo(partition2.getConsumerVertexGroups().get(0));
- assertThat(partition4.getConsumerVertexGroups().get(0))
- .containsExactlyInAnyOrder(vertex4.getID());
+ assertThat(consumerVertexGroup2).containsExactlyInAnyOrder(vertex3.getID());
+ assertThat(partition3.getConsumerVertexGroups().get(0)).isEqualTo(consumerVertexGroup2);
+ assertThat(consumerVertexGroup3).containsExactlyInAnyOrder(vertex4.getID());
// check inputs of the execution vertices
- assertThat(vertex1.getConsumedPartitionGroup(0))
- .containsExactlyInAnyOrder(partition1.getPartitionId());
- assertThat(vertex2.getConsumedPartitionGroup(0))
- .isEqualTo(vertex1.getConsumedPartitionGroup(0));
- assertThat(vertex3.getConsumedPartitionGroup(0))
+ ConsumedPartitionGroup consumedPartitionGroup1 = vertex1.getConsumedPartitionGroup(0);
+ ConsumedPartitionGroup consumedPartitionGroup2 = vertex3.getConsumedPartitionGroup(0);
+ ConsumedPartitionGroup consumedPartitionGroup3 = vertex4.getConsumedPartitionGroup(0);
+ assertThat(consumedPartitionGroup1).containsExactlyInAnyOrder(partition1.getPartitionId());
+ assertThat(vertex2.getConsumedPartitionGroup(0)).isEqualTo(consumedPartitionGroup1);
+ assertThat(consumedPartitionGroup2)
.containsExactlyInAnyOrder(
partition2.getPartitionId(), partition3.getPartitionId());
- assertThat(vertex4.getConsumedPartitionGroup(0))
- .containsExactlyInAnyOrder(partition4.getPartitionId());
+ assertThat(consumedPartitionGroup3).containsExactlyInAnyOrder(partition4.getPartitionId());
+
+ // check the consumerVertexGroups and consumedPartitionGroups are properly set
+ assertThat(consumerVertexGroup1.getConsumedPartitionGroup())
+ .isEqualTo(consumedPartitionGroup1);
+ assertThat(consumedPartitionGroup1.getConsumerVertexGroup())
+ .isEqualTo(consumerVertexGroup1);
+ assertThat(consumerVertexGroup2.getConsumedPartitionGroup())
+ .isEqualTo(consumedPartitionGroup2);
+ assertThat(consumedPartitionGroup2.getConsumerVertexGroup())
+ .isEqualTo(consumerVertexGroup2);
+ assertThat(consumerVertexGroup3.getConsumedPartitionGroup())
+ .isEqualTo(consumedPartitionGroup3);
+ assertThat(consumedPartitionGroup3.getConsumerVertexGroup())
+ .isEqualTo(consumerVertexGroup3);
}
private void testGetMaxNumEdgesToTarget(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
index 3bc9c404fae..b549da3e976 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
@@ -19,7 +19,6 @@
package org.apache.flink.runtime.scheduler.strategy;
import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.util.IterableUtils;
@@ -30,8 +29,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import static org.apache.flink.util.Preconditions.checkNotNull;
-
/** A simple scheduling execution vertex for testing purposes. */
public class TestingSchedulingExecutionVertex implements SchedulingExecutionVertex {
@@ -47,17 +44,12 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
private ExecutionState executionState;
public TestingSchedulingExecutionVertex(
- JobVertexID jobVertexId,
- int subtaskIndex,
- List<ConsumedPartitionGroup> consumedPartitionGroups,
- Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
- resultPartitionsById,
- ExecutionState executionState) {
+ JobVertexID jobVertexId, int subtaskIndex, ExecutionState executionState) {
this.executionVertexId = new ExecutionVertexID(jobVertexId, subtaskIndex);
- this.consumedPartitionGroups = checkNotNull(consumedPartitionGroups);
+ this.consumedPartitionGroups = new ArrayList<>();
this.producedPartitions = new ArrayList<>();
- this.resultPartitionsById = checkNotNull(resultPartitionsById);
+ this.resultPartitionsById = new HashMap<>();
this.executionState = executionState;
}
@@ -90,22 +82,6 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
return consumedPartitionGroups;
}
- void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) {
- final ConsumedPartitionGroup consumedPartitionGroup =
- ConsumedPartitionGroup.fromSinglePartition(
- consumedPartition.getNumConsumers(),
- consumedPartition.getId(),
- consumedPartition.getResultType());
-
- consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
- if (consumedPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) {
- consumedPartitionGroup.partitionFinished();
- }
-
- this.consumedPartitionGroups.add(consumedPartitionGroup);
- this.resultPartitionsById.putIfAbsent(consumedPartition.getId(), consumedPartition);
- }
-
void addConsumedPartitionGroup(
ConsumedPartitionGroup consumedPartitionGroup,
Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
@@ -131,9 +107,6 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
public static class Builder {
private JobVertexID jobVertexId = new JobVertexID();
private int subtaskIndex = 0;
- private final List<ConsumedPartitionGroup> consumedPartitionGroups = new ArrayList<>();
- private final Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
- resultPartitionsById = new HashMap<>();
private ExecutionState executionState = ExecutionState.CREATED;
Builder withExecutionVertexID(JobVertexID jobVertexId, int subtaskIndex) {
@@ -142,39 +115,13 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
return this;
}
- public Builder withConsumedPartitionGroups(
- List<ConsumedPartitionGroup> consumedPartitionGroups,
- Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
- resultPartitionsById) {
- this.resultPartitionsById.putAll(resultPartitionsById);
- final ResultPartitionType resultType =
- resultPartitionsById.values().iterator().next().getResultType();
-
- for (ConsumedPartitionGroup partitionGroup : consumedPartitionGroups) {
- List<IntermediateResultPartitionID> partitionIds =
- new ArrayList<>(partitionGroup.size());
- for (IntermediateResultPartitionID partitionId : partitionGroup) {
- partitionIds.add(partitionId);
- }
- this.consumedPartitionGroups.add(
- ConsumedPartitionGroup.fromMultiplePartitions(
- partitionGroup.getNumConsumers(), partitionIds, resultType));
- }
- return this;
- }
-
public Builder withExecutionState(ExecutionState executionState) {
this.executionState = executionState;
return this;
}
public TestingSchedulingExecutionVertex build() {
- return new TestingSchedulingExecutionVertex(
- jobVertexId,
- subtaskIndex,
- consumedPartitionGroups,
- resultPartitionsById,
- executionState);
+ return new TestingSchedulingExecutionVertex(jobVertexId, subtaskIndex, executionState);
}
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
index 70759fd2b6c..77514ba57b2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
@@ -25,10 +25,8 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import javax.annotation.Nullable;
import java.util.ArrayList;
-import java.util.Collection;
import java.util.Collections;
import java.util.List;
-import java.util.stream.Collectors;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
@@ -102,18 +100,8 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
return Collections.unmodifiableList(consumedPartitionGroups);
}
- void addConsumerGroup(
- Collection<TestingSchedulingExecutionVertex> consumerVertices,
- ResultPartitionType resultPartitionType) {
+ void addConsumerGroup(ConsumerVertexGroup consumerVertexGroup) {
checkState(this.consumerVertexGroup == null);
-
- final ConsumerVertexGroup consumerVertexGroup =
- ConsumerVertexGroup.fromMultipleVertices(
- consumerVertices.stream()
- .map(TestingSchedulingExecutionVertex::getId)
- .collect(Collectors.toList()),
- resultPartitionType);
-
this.consumerVertexGroup = consumerVertexGroup;
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
index ec86fa36d9b..b9653db3cf0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
@@ -38,6 +38,7 @@ import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
+import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
/** A simple scheduling topology for testing purposes. */
@@ -189,17 +190,12 @@ public class TestingSchedulingTopology implements SchedulingTopology {
TestingSchedulingExecutionVertex consumer,
ResultPartitionType resultPartitionType) {
- final TestingSchedulingResultPartition resultPartition =
- new TestingSchedulingResultPartition.Builder()
- .withResultPartitionType(resultPartitionType)
- .build();
-
- resultPartition.addConsumerGroup(
- Collections.singleton(consumer), resultPartition.getResultType());
- resultPartition.setProducer(producer);
-
- producer.addProducedPartition(resultPartition);
- consumer.addConsumedPartition(resultPartition);
+ connectConsumersToProducers(
+ Collections.singletonList(consumer),
+ Collections.singletonList(producer),
+ new IntermediateDataSetID(),
+ resultPartitionType,
+ ResultPartitionState.ALL_DATA_PRODUCED);
updateVertexResultPartitions(producer);
updateVertexResultPartitions(consumer);
@@ -223,6 +219,142 @@ public class TestingSchedulingTopology implements SchedulingTopology {
return new ProducerConsumerAllToAllConnectionBuilder(producers, consumers);
}
+ private static List<TestingSchedulingResultPartition> connectConsumersToProducers(
+ final List<TestingSchedulingExecutionVertex> consumers,
+ final List<TestingSchedulingExecutionVertex> producers,
+ final IntermediateDataSetID intermediateDataSetId,
+ final ResultPartitionType resultPartitionType,
+ final ResultPartitionState resultPartitionState) {
+
+ final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>();
+
+ final ConnectionResult connectionResult =
+ connectConsumersToProducersById(
+ consumers.stream()
+ .map(SchedulingExecutionVertex::getId)
+ .collect(Collectors.toList()),
+ producers.stream()
+ .map(SchedulingExecutionVertex::getId)
+ .collect(Collectors.toList()),
+ intermediateDataSetId,
+ resultPartitionType);
+
+ final ConsumedPartitionGroup consumedPartitionGroup =
+ connectionResult.getConsumedPartitionGroup();
+ final ConsumerVertexGroup consumerVertexGroup = connectionResult.getConsumerVertexGroup();
+
+ final TestingSchedulingResultPartition.Builder resultPartitionBuilder =
+ new TestingSchedulingResultPartition.Builder()
+ .withIntermediateDataSetID(intermediateDataSetId)
+ .withResultPartitionType(resultPartitionType)
+ .withResultPartitionState(resultPartitionState);
+
+ for (int i = 0; i < producers.size(); i++) {
+ final TestingSchedulingExecutionVertex producer = producers.get(i);
+ final IntermediateResultPartitionID partitionId =
+ connectionResult.getResultPartitions().get(i);
+ final TestingSchedulingResultPartition resultPartition =
+ resultPartitionBuilder
+ .withPartitionNum(partitionId.getPartitionNumber())
+ .build();
+
+ producer.addProducedPartition(resultPartition);
+
+ resultPartition.setProducer(producer);
+ resultPartitions.add(resultPartition);
+ resultPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
+ resultPartition.addConsumerGroup(consumerVertexGroup);
+
+ if (resultPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) {
+ consumedPartitionGroup.partitionFinished();
+ }
+ }
+
+ final Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
+ consumedPartitionById =
+ resultPartitions.stream()
+ .collect(
+ Collectors.toMap(
+ TestingSchedulingResultPartition::getId,
+ Function.identity()));
+ for (TestingSchedulingExecutionVertex consumer : consumers) {
+ consumer.addConsumedPartitionGroup(consumedPartitionGroup, consumedPartitionById);
+ }
+
+ return resultPartitions;
+ }
+
+ public static ConnectionResult connectConsumersToProducersById(
+ final List<ExecutionVertexID> consumers,
+ final List<ExecutionVertexID> producers,
+ final IntermediateDataSetID intermediateDataSetId,
+ final ResultPartitionType resultPartitionType) {
+
+ final List<IntermediateResultPartitionID> resultPartitions = new ArrayList<>();
+ for (ExecutionVertexID producer : producers) {
+ final IntermediateResultPartitionID resultPartition =
+ new IntermediateResultPartitionID(
+ intermediateDataSetId, producer.getSubtaskIndex());
+ resultPartitions.add(resultPartition);
+ }
+
+ final ConsumedPartitionGroup consumedPartitionGroup =
+ createConsumedPartitionGroup(
+ consumers.size(), resultPartitions, resultPartitionType);
+ final ConsumerVertexGroup consumerVertexGroup =
+ createConsumerVertexGroup(consumers, resultPartitionType);
+
+ consumedPartitionGroup.setConsumerVertexGroup(consumerVertexGroup);
+ consumerVertexGroup.setConsumedPartitionGroup(consumedPartitionGroup);
+
+ return new ConnectionResult(resultPartitions, consumedPartitionGroup, consumerVertexGroup);
+ }
+
+ private static ConsumedPartitionGroup createConsumedPartitionGroup(
+ final int numConsumers,
+ final List<IntermediateResultPartitionID> consumedPartitions,
+ final ResultPartitionType resultPartitionType) {
+ return ConsumedPartitionGroup.fromMultiplePartitions(
+ numConsumers, consumedPartitions, resultPartitionType);
+ }
+
+ private static ConsumerVertexGroup createConsumerVertexGroup(
+ final List<ExecutionVertexID> consumers,
+ final ResultPartitionType resultPartitionType) {
+ return ConsumerVertexGroup.fromMultipleVertices(consumers, resultPartitionType);
+ }
+
+ /**
+ * The result of connecting a set of consumers to their producers, including the created result
+ * partitions and the consumption groups.
+ */
+ public static class ConnectionResult {
+ private final List<IntermediateResultPartitionID> resultPartitions;
+ private final ConsumedPartitionGroup consumedPartitionGroup;
+ private final ConsumerVertexGroup consumerVertexGroup;
+
+ public ConnectionResult(
+ final List<IntermediateResultPartitionID> resultPartitions,
+ final ConsumedPartitionGroup consumedPartitionGroup,
+ final ConsumerVertexGroup consumerVertexGroup) {
+ this.resultPartitions = checkNotNull(resultPartitions);
+ this.consumedPartitionGroup = checkNotNull(consumedPartitionGroup);
+ this.consumerVertexGroup = checkNotNull(consumerVertexGroup);
+ }
+
+ public List<IntermediateResultPartitionID> getResultPartitions() {
+ return resultPartitions;
+ }
+
+ public ConsumedPartitionGroup getConsumedPartitionGroup() {
+ return consumedPartitionGroup;
+ }
+
+ public ConsumerVertexGroup getConsumerVertexGroup() {
+ return consumerVertexGroup;
+ }
+ }
+
/** Builder for {@link TestingSchedulingResultPartition}. */
public abstract class ProducerConsumerConnectionBuilder {
@@ -265,11 +397,6 @@ public class TestingSchedulingTopology implements SchedulingTopology {
return resultPartitions;
}
- TestingSchedulingResultPartition.Builder initTestingSchedulingResultPartitionBuilder() {
- return new TestingSchedulingResultPartition.Builder()
- .withResultPartitionType(resultPartitionType);
- }
-
protected abstract List<TestingSchedulingResultPartition> connect();
}
@@ -292,25 +419,15 @@ public class TestingSchedulingTopology implements SchedulingTopology {
protected List<TestingSchedulingResultPartition> connect() {
final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>();
final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID();
-
for (int idx = 0; idx < producers.size(); idx++) {
- final TestingSchedulingExecutionVertex producer = producers.get(idx);
- final TestingSchedulingExecutionVertex consumer = consumers.get(idx);
-
- final TestingSchedulingResultPartition resultPartition =
- initTestingSchedulingResultPartitionBuilder()
- .withIntermediateDataSetID(intermediateDataSetId)
- .withResultPartitionState(resultPartitionState)
- .withPartitionNum(idx)
- .build();
- resultPartition.setProducer(producer);
- producer.addProducedPartition(resultPartition);
- consumer.addConsumedPartition(resultPartition);
- resultPartition.addConsumerGroup(
- Collections.singleton(consumer), resultPartitionType);
- resultPartitions.add(resultPartition);
+ resultPartitions.addAll(
+ connectConsumersToProducers(
+ Collections.singletonList(consumers.get(idx)),
+ Collections.singletonList(producers.get(idx)),
+ intermediateDataSetId,
+ resultPartitionType,
+ resultPartitionState));
}
-
return resultPartitions;
}
}
@@ -330,53 +447,12 @@ public class TestingSchedulingTopology implements SchedulingTopology {
@Override
protected List<TestingSchedulingResultPartition> connect() {
- final List<TestingSchedulingResultPartition> resultPartitions = new ArrayList<>();
- final IntermediateDataSetID intermediateDataSetId = new IntermediateDataSetID();
-
- TestingSchedulingResultPartition.Builder resultPartitionBuilder =
- initTestingSchedulingResultPartitionBuilder()
- .withIntermediateDataSetID(intermediateDataSetId)
- .withResultPartitionState(resultPartitionState);
-
- int partitionNum = 0;
-
- for (TestingSchedulingExecutionVertex producer : producers) {
-
- final TestingSchedulingResultPartition resultPartition =
- resultPartitionBuilder.withPartitionNum(partitionNum++).build();
- resultPartition.setProducer(producer);
- producer.addProducedPartition(resultPartition);
-
- resultPartition.addConsumerGroup(consumers, resultPartitionType);
- resultPartitions.add(resultPartition);
- }
-
- ConsumedPartitionGroup consumedPartitionGroup =
- ConsumedPartitionGroup.fromMultiplePartitions(
- consumers.size(),
- resultPartitions.stream()
- .map(TestingSchedulingResultPartition::getId)
- .collect(Collectors.toList()),
- resultPartitions.get(0).getResultType());
- Map<IntermediateResultPartitionID, TestingSchedulingResultPartition>
- consumedPartitionById =
- resultPartitions.stream()
- .collect(
- Collectors.toMap(
- TestingSchedulingResultPartition::getId,
- Function.identity()));
- for (TestingSchedulingExecutionVertex consumer : consumers) {
- consumer.addConsumedPartitionGroup(consumedPartitionGroup, consumedPartitionById);
- }
-
- for (TestingSchedulingResultPartition resultPartition : resultPartitions) {
- resultPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
- if (resultPartition.getState() == ResultPartitionState.ALL_DATA_PRODUCED) {
- consumedPartitionGroup.partitionFinished();
- }
- }
-
- return resultPartitions;
+ return connectConsumersToProducers(
+ consumers,
+ producers,
+ new IntermediateDataSetID(),
+ resultPartitionType,
+ resultPartitionState);
}
}