You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by yi...@apache.org on 2022/08/08 14:58:42 UTC
[flink] 02/02: [FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side
This is an automated email from the ASF dual-hosted git repository.
yingjie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
commit 72405361610f1576e0f57ea4e957fafbbbaf0910
Author: kevin.cyj <ke...@alibaba-inc.com>
AuthorDate: Mon Jul 25 16:03:28 2022 +0800
[FLINK-28663][runtime] Allow multiple downstream consumer job vertices sharing the same intermediate dataset at scheduler side
This closes #20350.
---
.../optimizer/plantranslate/JobGraphGenerator.java | 5 +-
.../TaskDeploymentDescriptorFactory.java | 9 +-
.../executiongraph/DefaultExecutionGraph.java | 15 +-
.../flink/runtime/executiongraph/EdgeManager.java | 24 +++-
.../executiongraph/EdgeManagerBuildUtil.java | 21 +--
.../flink/runtime/executiongraph/Execution.java | 61 +++++----
.../runtime/executiongraph/IntermediateResult.java | 70 ++++++++--
.../IntermediateResultPartition.java | 58 +++++---
.../RestartPipelinedRegionFailoverStrategy.java | 12 +-
.../SchedulingPipelinedRegionComputeUtil.java | 32 ++---
.../io/network/NettyShuffleEnvironment.java | 36 +++--
.../runtime/jobgraph/IntermediateDataSet.java | 40 ++++--
.../org/apache/flink/runtime/jobgraph/JobEdge.java | 48 ++-----
.../apache/flink/runtime/jobgraph/JobGraph.java | 5 +-
.../apache/flink/runtime/jobgraph/JobVertex.java | 41 +++---
.../SsgNetworkMemoryCalculationUtils.java | 30 ++--
.../adapter/DefaultExecutionTopology.java | 10 +-
.../scheduler/adapter/DefaultResultPartition.java | 11 +-
.../scheduler/strategy/ConsumedPartitionGroup.java | 19 ++-
.../strategy/SchedulingResultPartition.java | 7 +-
.../strategy/VertexwiseSchedulingStrategy.java | 10 +-
.../BlockingResultPartitionReleaseTest.java | 151 +++++++++++++++++++++
.../DefaultExecutionGraphConstructionTest.java | 41 ++++++
.../executiongraph/EdgeManagerBuildUtilTest.java | 2 +-
.../runtime/executiongraph/EdgeManagerTest.java | 88 +++++++++---
.../executiongraph/ExecutionGraphTestUtils.java | 19 +++
.../executiongraph/ExecutionJobVertexTest.java | 2 +-
.../IntermediateResultPartitionTest.java | 62 ++++++++-
.../RemoveCachedShuffleDescriptorTest.java | 133 ++++--------------
.../partition/NoOpJobMasterPartitionTracker.java | 5 +-
.../flink/runtime/jobgraph/JobTaskVertexTest.java | 43 +++++-
.../jobmaster/JobIntermediateDatasetReuseTest.java | 3 +-
.../runtime/scheduler/SchedulerTestingUtils.java | 137 +++++++++++++++++++
.../adapter/DefaultExecutionTopologyTest.java | 19 ++-
.../adapter/DefaultExecutionVertexTest.java | 1 +
.../adapter/DefaultResultPartitionTest.java | 20 ++-
.../adaptivebatch/AdaptiveBatchSchedulerTest.java | 2 +-
.../forwardgroup/ForwardGroupComputeUtilTest.java | 12 +-
.../strategy/TestingSchedulingExecutionVertex.java | 7 +-
.../strategy/TestingSchedulingResultPartition.java | 9 +-
.../strategy/TestingSchedulingTopology.java | 1 +
.../api/graph/StreamingJobGraphGenerator.java | 7 +-
.../validation/TestJobDataFlowValidator.java | 34 ++---
43 files changed, 958 insertions(+), 404 deletions(-)
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
index 0da624f1e7f..887da47ed0c 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
@@ -1246,7 +1246,7 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
predecessorVertex != null,
"Bug: Chained task has not been assigned its containing vertex when connecting.");
- predecessorVertex.createAndAddResultDataSet(
+ predecessorVertex.getOrCreateResultDataSet(
// use specified intermediateDataSetID
new IntermediateDataSetID(
((BlockingShuffleOutputFormat) userCodeObject).getIntermediateDataSetId()),
@@ -1326,7 +1326,7 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
JobEdge edge =
targetVertex.connectNewDataSetAsInput(
- sourceVertex, distributionPattern, resultType);
+ sourceVertex, distributionPattern, resultType, isBroadcast);
// -------------- configure the source task's ship strategy strategies in task config
// --------------
@@ -1403,7 +1403,6 @@ public class JobGraphGenerator implements Visitor<PlanNode> {
channel.getTempMode() == TempMode.NONE ? null : channel.getTempMode().toString();
edge.setShipStrategyName(shipStrategy);
- edge.setBroadcast(isBroadcast);
edge.setForward(channel.getShipStrategy() == ShipStrategyType.FORWARD);
edge.setPreProcessingOperationName(localStrategy);
edge.setOperatorLevelCachingDescription(caching);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index a2c45ed0c47..ba10dd59de8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -133,7 +133,9 @@ public class TaskDeploymentDescriptorFactory {
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
SubpartitionIndexRange consumedSubpartitionRange =
computeConsumedSubpartitionRange(
- resultPartition, executionId.getSubtaskIndex());
+ consumedPartitionGroup.getNumConsumers(),
+ resultPartition,
+ executionId.getSubtaskIndex());
IntermediateDataSetID resultId = consumedIntermediateResult.getId();
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
@@ -164,8 +166,9 @@ public class TaskDeploymentDescriptorFactory {
}
public static SubpartitionIndexRange computeConsumedSubpartitionRange(
- IntermediateResultPartition resultPartition, int consumerSubtaskIndex) {
- int numConsumers = resultPartition.getConsumerVertexGroup().size();
+ int numConsumers,
+ IntermediateResultPartition resultPartition,
+ int consumerSubtaskIndex) {
int consumerIndex = consumerSubtaskIndex % numConsumers;
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
int numSubpartitions = resultPartition.getNumberOfSubpartitions();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 2f160f14605..ce1f18c5504 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -1399,6 +1399,7 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
final List<ConsumedPartitionGroup> releasablePartitionGroups) {
if (releasablePartitionGroups.size() > 0) {
+ final List<ResultPartitionID> releasablePartitionIds = new ArrayList<>();
// Remove the cache of ShuffleDescriptors when ConsumedPartitionGroups are released
for (ConsumedPartitionGroup releasablePartitionGroup : releasablePartitionGroups) {
@@ -1406,15 +1407,17 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
checkNotNull(
intermediateResults.get(
releasablePartitionGroup.getIntermediateDataSetID()));
+ for (IntermediateResultPartitionID partitionId : releasablePartitionGroup) {
+ IntermediateResultPartition partition =
+ totalResult.getPartitionById(partitionId);
+ partition.markPartitionGroupReleasable(releasablePartitionGroup);
+ if (partition.canBeReleased()) {
+ releasablePartitionIds.add(createResultPartitionId(partitionId));
+ }
+ }
totalResult.clearCachedInformationForPartitionGroup(releasablePartitionGroup);
}
- final List<ResultPartitionID> releasablePartitionIds =
- releasablePartitionGroups.stream()
- .flatMap(IterableUtils::toStream)
- .map(this::createResultPartitionId)
- .collect(Collectors.toList());
-
partitionTracker.stopTrackingAndReleasePartitions(releasablePartitionIds);
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
index 8efc25a913b..1a437e16030 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManager.java
@@ -30,12 +30,11 @@ import java.util.List;
import java.util.Map;
import static org.apache.flink.util.Preconditions.checkNotNull;
-import static org.apache.flink.util.Preconditions.checkState;
/** Class that manages all the connections between tasks. */
public class EdgeManager {
- private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> partitionConsumers =
+ private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>> partitionConsumers =
new HashMap<>();
private final Map<ExecutionVertexID, List<ConsumedPartitionGroup>> vertexConsumedPartitions =
@@ -50,9 +49,9 @@ public class EdgeManager {
checkNotNull(consumerVertexGroup);
- checkState(
- partitionConsumers.putIfAbsent(resultPartitionId, consumerVertexGroup) == null,
- "Currently one partition can have at most one consumer group");
+ List<ConsumerVertexGroup> groups =
+ getConsumerVertexGroupsForPartitionInternal(resultPartitionId);
+ groups.add(consumerVertexGroup);
}
public void connectVertexWithConsumedPartitionGroup(
@@ -66,14 +65,20 @@ public class EdgeManager {
consumedPartitions.add(consumedPartitionGroup);
}
+ private List<ConsumerVertexGroup> getConsumerVertexGroupsForPartitionInternal(
+ IntermediateResultPartitionID resultPartitionId) {
+ return partitionConsumers.computeIfAbsent(resultPartitionId, id -> new ArrayList<>());
+ }
+
private List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertexInternal(
ExecutionVertexID executionVertexId) {
return vertexConsumedPartitions.computeIfAbsent(executionVertexId, id -> new ArrayList<>());
}
- public ConsumerVertexGroup getConsumerVertexGroupForPartition(
+ public List<ConsumerVertexGroup> getConsumerVertexGroupsForPartition(
IntermediateResultPartitionID resultPartitionId) {
- return partitionConsumers.get(resultPartitionId);
+ return Collections.unmodifiableList(
+ getConsumerVertexGroupsForPartitionInternal(resultPartitionId));
}
public List<ConsumedPartitionGroup> getConsumedPartitionGroupsForVertex(
@@ -100,4 +105,9 @@ public class EdgeManager {
return Collections.unmodifiableList(
getConsumedPartitionGroupsByIdInternal(resultPartitionId));
}
+
+ public int getNumberOfConsumedPartitionGroupsById(
+ IntermediateResultPartitionID resultPartitionId) {
+ return getConsumedPartitionGroupsByIdInternal(resultPartitionId).size();
+ }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
index 3aa6e59de0d..8ac55b7b7a2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtil.java
@@ -89,7 +89,7 @@ public class EdgeManagerBuildUtil {
.collect(Collectors.toList());
ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
- consumedPartitions, intermediateResult);
+ taskVertices.length, consumedPartitions, intermediateResult);
for (ExecutionVertex ev : taskVertices) {
ev.addConsumedPartitionGroup(consumedPartitionGroup);
}
@@ -122,7 +122,9 @@ public class EdgeManagerBuildUtil {
ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
- partition.getPartitionId(), intermediateResult);
+ consumerVertexGroup.size(),
+ partition.getPartitionId(),
+ intermediateResult);
executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
}
} else if (sourceCount > targetCount) {
@@ -147,20 +149,19 @@ public class EdgeManagerBuildUtil {
ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
- consumedPartitions, intermediateResult);
+ consumerVertexGroup.size(), consumedPartitions, intermediateResult);
executionVertex.addConsumedPartitionGroup(consumedPartitionGroup);
}
} else {
for (int partitionNum = 0; partitionNum < sourceCount; partitionNum++) {
+ int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
+ int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
IntermediateResultPartition partition =
intermediateResult.getPartitions()[partitionNum];
ConsumedPartitionGroup consumedPartitionGroup =
createAndRegisterConsumedPartitionGroupToEdgeManager(
- partition.getPartitionId(), intermediateResult);
-
- int start = (partitionNum * targetCount + sourceCount - 1) / sourceCount;
- int end = ((partitionNum + 1) * targetCount + sourceCount - 1) / sourceCount;
+ end - start, partition.getPartitionId(), intermediateResult);
List<ExecutionVertexID> consumers = new ArrayList<>(end - start);
@@ -179,21 +180,23 @@ public class EdgeManagerBuildUtil {
}
private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
+ int numConsumers,
IntermediateResultPartitionID consumedPartitionId,
IntermediateResult intermediateResult) {
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromSinglePartition(
- consumedPartitionId, intermediateResult.getResultType());
+ numConsumers, consumedPartitionId, intermediateResult.getResultType());
registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
return consumedPartitionGroup;
}
private static ConsumedPartitionGroup createAndRegisterConsumedPartitionGroupToEdgeManager(
+ int numConsumers,
List<IntermediateResultPartitionID> consumedPartitions,
IntermediateResult intermediateResult) {
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromMultiplePartitions(
- consumedPartitions, intermediateResult.getResultType());
+ numConsumers, consumedPartitions, intermediateResult.getResultType());
registerConsumedPartitionGroupToEdgeManager(consumedPartitionGroup, intermediateResult);
return consumedPartitionGroup;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
index 657effb66f6..f6de9059869 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/Execution.java
@@ -65,6 +65,7 @@ import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
@@ -497,10 +498,7 @@ public class Execution
}
private static int getPartitionMaxParallelism(IntermediateResultPartition partition) {
- return partition
- .getIntermediateResult()
- .getConsumerExecutionJobVertex()
- .getMaxParallelism();
+ return partition.getIntermediateResult().getConsumersMaxParallelism();
}
/**
@@ -718,31 +716,40 @@ public class Execution
}
private void updatePartitionConsumers(final IntermediateResultPartition partition) {
- final Optional<ConsumerVertexGroup> consumerVertexGroup =
- partition.getConsumerVertexGroupOptional();
- if (!consumerVertexGroup.isPresent()) {
+ final List<ConsumerVertexGroup> consumerVertexGroups = partition.getConsumerVertexGroups();
+ if (consumerVertexGroups.isEmpty()) {
return;
}
- for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
- final ExecutionVertex consumerVertex =
- vertex.getExecutionGraphAccessor().getExecutionVertexOrThrow(consumerVertexId);
- final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
- final ExecutionState consumerState = consumer.getState();
-
- // ----------------------------------------------------------------
- // Consumer is recovering or running => send update message now
- // Consumer is deploying => cache the partition info which would be
- // sent after switching to running
- // ----------------------------------------------------------------
- if (consumerState == DEPLOYING
- || consumerState == RUNNING
- || consumerState == INITIALIZING) {
- final PartitionInfo partitionInfo = createPartitionInfo(partition);
-
- if (consumerState == DEPLOYING) {
- consumerVertex.cachePartitionInfo(partitionInfo);
- } else {
- consumer.sendUpdatePartitionInfoRpcCall(Collections.singleton(partitionInfo));
+ final Set<ExecutionVertexID> updatedVertices = new HashSet<>();
+ for (ConsumerVertexGroup consumerVertexGroup : consumerVertexGroups) {
+ for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
+ if (updatedVertices.contains(consumerVertexId)) {
+ continue;
+ }
+
+ final ExecutionVertex consumerVertex =
+ vertex.getExecutionGraphAccessor()
+ .getExecutionVertexOrThrow(consumerVertexId);
+ final Execution consumer = consumerVertex.getCurrentExecutionAttempt();
+ final ExecutionState consumerState = consumer.getState();
+
+ // ----------------------------------------------------------------
+ // Consumer is recovering or running => send update message now
+ // Consumer is deploying => cache the partition info which would be
+ // sent after switching to running
+ // ----------------------------------------------------------------
+ if (consumerState == DEPLOYING
+ || consumerState == RUNNING
+ || consumerState == INITIALIZING) {
+ final PartitionInfo partitionInfo = createPartitionInfo(partition);
+ updatedVertices.add(consumerVertexId);
+
+ if (consumerState == DEPLOYING) {
+ consumerVertex.cachePartitionInfo(partitionInfo);
+ } else {
+ consumer.sendUpdatePartitionInfoRpcCall(
+ Collections.singleton(partitionInfo));
+ }
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
index 4b666b5742b..fd8be042f6f 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
@@ -32,12 +32,15 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import static org.apache.flink.util.Preconditions.checkArgument;
import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
public class IntermediateResult {
@@ -68,6 +71,9 @@ public class IntermediateResult {
private final Map<ConsumedPartitionGroup, MaybeOffloaded<ShuffleDescriptor[]>>
shuffleDescriptorCache;
+ /** All consumer job vertex ids of this dataset. */
+ private final List<JobVertexID> consumerVertices = new ArrayList<>();
+
public IntermediateResult(
IntermediateDataSet intermediateDataSet,
ExecutionJobVertex producer,
@@ -95,6 +101,10 @@ public class IntermediateResult {
this.resultType = checkNotNull(resultType);
this.shuffleDescriptorCache = new HashMap<>();
+
+ intermediateDataSet
+ .getConsumers()
+ .forEach(jobEdge -> consumerVertices.add(jobEdge.getTarget().getID()));
}
public void setPartition(int partitionNumber, IntermediateResultPartition partition) {
@@ -124,6 +134,10 @@ public class IntermediateResult {
return partitions;
}
+ public List<JobVertexID> getConsumerVertices() {
+ return consumerVertices;
+ }
+
/**
* Returns the partition with the given ID.
*
@@ -162,20 +176,60 @@ public class IntermediateResult {
return numParallelProducers;
}
- ExecutionJobVertex getConsumerExecutionJobVertex() {
- final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
- final JobVertexID consumerJobVertexId = consumer.getTarget().getID();
- return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId));
+ /**
+ * Currently, this method is only used to compute the maximum number of consumers. For dynamic
+ * graph, it should be called before adaptively deciding the downstream consumer parallelism.
+ */
+ int getConsumersParallelism() {
+ List<JobEdge> consumers = intermediateDataSet.getConsumers();
+ checkState(!consumers.isEmpty());
+
+ InternalExecutionGraphAccessor graph = getProducer().getGraph();
+ int consumersParallelism =
+ graph.getJobVertex(consumers.get(0).getTarget().getID()).getParallelism();
+ if (consumers.size() == 1) {
+ return consumersParallelism;
+ }
+
+ // sanity check, all consumer vertices must have the same parallelism:
+ // 1. for vertices that are not assigned a parallelism initially (for example, dynamic
+ // graph), the parallelisms will all be -1 (parallelism not decided yet)
+ // 2. for vertices that are initially assigned a parallelism, the parallelisms must be the
+ // same, which is guaranteed at compilation phase
+ for (JobVertexID jobVertexID : consumerVertices) {
+ checkState(
+ consumersParallelism == graph.getJobVertex(jobVertexID).getParallelism(),
+ "Consumers must have the same parallelism.");
+ }
+ return consumersParallelism;
+ }
+
+ int getConsumersMaxParallelism() {
+ List<JobEdge> consumers = intermediateDataSet.getConsumers();
+ checkState(!consumers.isEmpty());
+
+ InternalExecutionGraphAccessor graph = getProducer().getGraph();
+ int consumersMaxParallelism =
+ graph.getJobVertex(consumers.get(0).getTarget().getID()).getMaxParallelism();
+ if (consumers.size() == 1) {
+ return consumersMaxParallelism;
+ }
+
+ // sanity check, all consumer vertices must have the same max parallelism
+ for (JobVertexID jobVertexID : consumerVertices) {
+ checkState(
+ consumersMaxParallelism == graph.getJobVertex(jobVertexID).getMaxParallelism(),
+ "Consumers must have the same max parallelism.");
+ }
+ return consumersMaxParallelism;
}
public DistributionPattern getConsumingDistributionPattern() {
- final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
- return consumer.getDistributionPattern();
+ return intermediateDataSet.getDistributionPattern();
}
public boolean isBroadcast() {
- final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
- return consumer.isBroadcast();
+ return intermediateDataSet.isBroadcast();
}
public int getConnectionIndex() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
index 5aef1b07471..9b9c176a3d9 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartition.java
@@ -21,11 +21,13 @@ package org.apache.flink.runtime.executiongraph;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ConsumerVertexGroup;
+import java.util.HashSet;
import java.util.List;
-import java.util.Optional;
+import java.util.Set;
import static org.apache.flink.util.Preconditions.checkState;
@@ -47,6 +49,12 @@ public class IntermediateResultPartition {
/** Whether this partition has produced some data. */
private boolean hasDataProduced = false;
+ /**
+ * Releasable {@link ConsumedPartitionGroup}s for this result partition. This result partition
+ * can be released if all {@link ConsumedPartitionGroup}s are releasable.
+ */
+ private final Set<ConsumedPartitionGroup> releasablePartitionGroups = new HashSet<>();
+
public IntermediateResultPartition(
IntermediateResult totalResult,
ExecutionVertex producer,
@@ -58,6 +66,25 @@ public class IntermediateResultPartition {
this.edgeManager = edgeManager;
}
+ public void markPartitionGroupReleasable(ConsumedPartitionGroup partitionGroup) {
+ releasablePartitionGroups.add(partitionGroup);
+ }
+
+ public boolean canBeReleased() {
+ if (releasablePartitionGroups.size()
+ != edgeManager.getNumberOfConsumedPartitionGroupsById(partitionId)) {
+ return false;
+ }
+ for (JobVertexID jobVertexId : totalResult.getConsumerVertices()) {
+ // for dynamic graph, if any consumer vertex is still not initialized, this result
+ // partition can not be released
+ if (!producer.getExecutionGraphAccessor().getJobVertex(jobVertexId).isInitialized()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
public ExecutionVertex getProducer() {
return producer;
}
@@ -78,15 +105,8 @@ public class IntermediateResultPartition {
return totalResult.getResultType();
}
- public ConsumerVertexGroup getConsumerVertexGroup() {
- Optional<ConsumerVertexGroup> consumerVertexGroup = getConsumerVertexGroupOptional();
- checkState(consumerVertexGroup.isPresent());
- return consumerVertexGroup.get();
- }
-
- public Optional<ConsumerVertexGroup> getConsumerVertexGroupOptional() {
- return Optional.ofNullable(
- getEdgeManager().getConsumerVertexGroupForPartition(partitionId));
+ public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+ return getEdgeManager().getConsumerVertexGroupsForPartition(partitionId);
}
public List<ConsumedPartitionGroup> getConsumedPartitionGroups() {
@@ -106,12 +126,13 @@ public class IntermediateResultPartition {
private int computeNumberOfSubpartitions() {
if (!getProducer().getExecutionGraphAccessor().isDynamic()) {
- ConsumerVertexGroup consumerVertexGroup = getConsumerVertexGroup();
- checkState(consumerVertexGroup.size() > 0);
+ List<ConsumerVertexGroup> consumerVertexGroups = getConsumerVertexGroups();
+ checkState(!consumerVertexGroups.isEmpty());
// The produced data is partitioned among a number of subpartitions, one for each
- // consuming sub task.
- return consumerVertexGroup.size();
+ // consuming sub task. All vertex groups must have the same number of consumers
+ // for non-dynamic graph.
+ return consumerVertexGroups.get(0).size();
} else {
if (totalResult.isBroadcast()) {
// for dynamic graph and broadcast result, we only produced one subpartition,
@@ -124,18 +145,16 @@ public class IntermediateResultPartition {
}
private int computeNumberOfMaxPossiblePartitionConsumers() {
- final ExecutionJobVertex consumerJobVertex =
- getIntermediateResult().getConsumerExecutionJobVertex();
final DistributionPattern distributionPattern =
getIntermediateResult().getConsumingDistributionPattern();
// decide the max possible consumer job vertex parallelism
- int maxConsumerJobVertexParallelism = consumerJobVertex.getParallelism();
+ int maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersParallelism();
if (maxConsumerJobVertexParallelism <= 0) {
+ maxConsumerJobVertexParallelism = getIntermediateResult().getConsumersMaxParallelism();
checkState(
- consumerJobVertex.getMaxParallelism() > 0,
+ maxConsumerJobVertexParallelism > 0,
"Neither the parallelism nor the max parallelism of a job vertex is set");
- maxConsumerJobVertexParallelism = consumerJobVertex.getMaxParallelism();
}
// compute number of subpartitions according to the distribution pattern
@@ -163,6 +182,7 @@ public class IntermediateResultPartition {
consumedPartitionGroup.partitionUnfinished();
}
}
+ releasablePartitionGroups.clear();
hasDataProduced = false;
for (ConsumedPartitionGroup consumedPartitionGroup : getConsumedPartitionGroups()) {
totalResult.clearCachedInformationForPartitionGroup(consumedPartitionGroup);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
index 6e7181ffd03..39d7fe72547 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/RestartPipelinedRegionFailoverStrategy.java
@@ -230,12 +230,12 @@ public class RestartPipelinedRegionFailoverStrategy implements FailoverStrategy
for (SchedulingExecutionVertex vertex : regionToRestart.getVertices()) {
for (SchedulingResultPartition producedPartition : vertex.getProducedResults()) {
- final Optional<ConsumerVertexGroup> consumerVertexGroup =
- producedPartition.getConsumerVertexGroup();
- if (consumerVertexGroup.isPresent()
- && !visitedConsumerVertexGroups.contains(consumerVertexGroup.get())) {
- visitedConsumerVertexGroups.add(consumerVertexGroup.get());
- consumerVertexGroupsToVisit.add(consumerVertexGroup.get());
+ for (ConsumerVertexGroup consumerVertexGroup :
+ producedPartition.getConsumerVertexGroups()) {
+ if (!visitedConsumerVertexGroups.contains(consumerVertexGroup)) {
+ visitedConsumerVertexGroups.add(consumerVertexGroup);
+ consumerVertexGroupsToVisit.add(consumerVertexGroup);
+ }
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
index de353be7ba8..6428284cd49 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/failover/flip1/SchedulingPipelinedRegionComputeUtil.java
@@ -32,7 +32,6 @@ import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
@@ -116,23 +115,20 @@ public final class SchedulingPipelinedRegionComputeUtil {
if (producedResult.getResultType().mustBePipelinedConsumed()) {
continue;
}
- final Optional<ConsumerVertexGroup> consumerVertexGroup =
- producedResult.getConsumerVertexGroup();
- if (!consumerVertexGroup.isPresent()) {
- continue;
- }
-
- for (ExecutionVertexID consumerVertexId : consumerVertexGroup.get()) {
- SchedulingExecutionVertex consumerVertex =
- executionVertexRetriever.apply(consumerVertexId);
- // Skip the ConsumerVertexGroup if its vertices are outside current
- // regions and cannot be merged
- if (!vertexToRegion.containsKey(consumerVertex)) {
- break;
- }
- if (!currentRegion.contains(consumerVertex)) {
- currentRegionOutEdges.add(
- regionIndices.get(vertexToRegion.get(consumerVertex)));
+ for (ConsumerVertexGroup consumerVertexGroup :
+ producedResult.getConsumerVertexGroups()) {
+ for (ExecutionVertexID consumerVertexId : consumerVertexGroup) {
+ SchedulingExecutionVertex consumerVertex =
+ executionVertexRetriever.apply(consumerVertexId);
+ // Skip the ConsumerVertexGroup if its vertices are outside current
+ // regions and cannot be merged
+ if (!vertexToRegion.containsKey(consumerVertex)) {
+ break;
+ }
+ if (!currentRegion.contains(consumerVertex)) {
+ currentRegionOutEdges.add(
+ regionIndices.get(vertexToRegion.get(consumerVertex)));
+ }
}
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
index aa45dd4b56f..2d55cf4c639 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironment.java
@@ -56,6 +56,7 @@ import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
@@ -93,7 +94,7 @@ public class NettyShuffleEnvironment
private final FileChannelManager fileChannelManager;
- private final Map<InputGateID, SingleInputGate> inputGatesById;
+ private final Map<InputGateID, Set<SingleInputGate>> inputGatesById;
private final ResultPartitionFactory resultPartitionFactory;
@@ -169,7 +170,7 @@ public class NettyShuffleEnvironment
}
@VisibleForTesting
- public Optional<InputGate> getInputGate(InputGateID id) {
+ public Optional<Collection<SingleInputGate>> getInputGate(InputGateID id) {
return Optional.ofNullable(inputGatesById.get(id));
}
@@ -260,8 +261,24 @@ public class NettyShuffleEnvironment
InputGateID id =
new InputGateID(
igdd.getConsumedResultId(), ownerContext.getExecutionAttemptID());
- inputGatesById.put(id, inputGate);
- inputGate.getCloseFuture().thenRun(() -> inputGatesById.remove(id));
+ Set<SingleInputGate> inputGateSet =
+ inputGatesById.computeIfAbsent(
+ id, ignored -> ConcurrentHashMap.newKeySet());
+ inputGateSet.add(inputGate);
+ inputGatesById.put(id, inputGateSet);
+ inputGate
+ .getCloseFuture()
+ .thenRun(
+ () ->
+ inputGatesById.computeIfPresent(
+ id,
+ (key, value) -> {
+ value.remove(inputGate);
+ if (value.isEmpty()) {
+ return null;
+ }
+ return value;
+ }));
inputGates[gateIndex] = inputGate;
}
@@ -297,17 +314,20 @@ public class NettyShuffleEnvironment
IntermediateDataSetID intermediateResultPartitionID =
partitionInfo.getIntermediateDataSetID();
InputGateID id = new InputGateID(intermediateResultPartitionID, consumerID);
- SingleInputGate inputGate = inputGatesById.get(id);
- if (inputGate == null) {
+ Set<SingleInputGate> inputGates = inputGatesById.get(id);
+ if (inputGates == null || inputGates.isEmpty()) {
return false;
}
+
ShuffleDescriptor shuffleDescriptor = partitionInfo.getShuffleDescriptor();
checkArgument(
shuffleDescriptor instanceof NettyShuffleDescriptor,
"Tried to update unknown channel with unknown ShuffleDescriptor %s.",
shuffleDescriptor.getClass().getName());
- inputGate.updateInputChannel(
- taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor);
+ for (SingleInputGate inputGate : inputGates) {
+ inputGate.updateInputChannel(
+ taskExecutorResourceId, (NettyShuffleDescriptor) shuffleDescriptor);
+ }
return true;
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
index d6b3abc3b52..2aad2613996 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/IntermediateDataSet.java
@@ -20,7 +20,8 @@ package org.apache.flink.runtime.jobgraph;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
-import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
@@ -39,11 +40,16 @@ public class IntermediateDataSet implements java.io.Serializable {
private final JobVertex producer; // the operation that produced this data set
- @Nullable private JobEdge consumer;
+ // All consumers must have the same partitioner and parallelism
+ private final List<JobEdge> consumers = new ArrayList<>();
// The type of partition to use at runtime
private final ResultPartitionType resultType;
+ private DistributionPattern distributionPattern;
+
+ private boolean isBroadcast;
+
// --------------------------------------------------------------------------------------------
public IntermediateDataSet(
@@ -63,9 +69,16 @@ public class IntermediateDataSet implements java.io.Serializable {
return producer;
}
- @Nullable
- public JobEdge getConsumer() {
- return consumer;
+ public List<JobEdge> getConsumers() {
+ return this.consumers;
+ }
+
+ public boolean isBroadcast() {
+ return isBroadcast;
+ }
+
+ public DistributionPattern getDistributionPattern() {
+ return distributionPattern;
}
public ResultPartitionType getResultType() {
@@ -75,10 +88,19 @@ public class IntermediateDataSet implements java.io.Serializable {
// --------------------------------------------------------------------------------------------
public void addConsumer(JobEdge edge) {
- checkState(
- this.consumer == null,
- "Currently one IntermediateDataSet can have at most one consumer.");
- this.consumer = edge;
+ // sanity check
+ checkState(id.equals(edge.getSourceId()), "Incompatible dataset id.");
+
+ if (consumers.isEmpty()) {
+ distributionPattern = edge.getDistributionPattern();
+ isBroadcast = edge.isBroadcast();
+ } else {
+ checkState(
+ distributionPattern == edge.getDistributionPattern(),
+ "Incompatible distribution pattern.");
+ checkState(isBroadcast == edge.isBroadcast(), "Incompatible broadcast type.");
+ }
+ consumers.add(edge);
}
// --------------------------------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
index 9772ff4dbb1..4649303c144 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobEdge.java
@@ -44,10 +44,7 @@ public class JobEdge implements java.io.Serializable {
private SubtaskStateMapper upstreamSubtaskStateMapper = SubtaskStateMapper.ROUND_ROBIN;
/** The data set at the source of the edge, may be null if the edge is not yet connected. */
- private IntermediateDataSet source;
-
- /** The id of the source intermediate data set. */
- private IntermediateDataSetID sourceId;
+ private final IntermediateDataSet source;
/**
* Optional name for the data shipping strategy (forward, partition hash, rebalance, ...), to be
@@ -55,7 +52,7 @@ public class JobEdge implements java.io.Serializable {
*/
private String shipStrategyName;
- private boolean isBroadcast;
+ private final boolean isBroadcast;
private boolean isForward;
@@ -74,36 +71,20 @@ public class JobEdge implements java.io.Serializable {
* @param source The data set that is at the source of this edge.
* @param target The operation that is at the target of this edge.
* @param distributionPattern The pattern that defines how the connection behaves in parallel.
+ * @param isBroadcast Whether the source broadcasts data to the target.
*/
public JobEdge(
- IntermediateDataSet source, JobVertex target, DistributionPattern distributionPattern) {
+ IntermediateDataSet source,
+ JobVertex target,
+ DistributionPattern distributionPattern,
+ boolean isBroadcast) {
if (source == null || target == null || distributionPattern == null) {
throw new NullPointerException();
}
this.target = target;
this.distributionPattern = distributionPattern;
this.source = source;
- this.sourceId = source.getId();
- }
-
- /**
- * Constructs a new job edge that refers to an intermediate result via the Id, rather than
- * directly through the intermediate data set structure.
- *
- * @param sourceId The id of the data set that is at the source of this edge.
- * @param target The operation that is at the target of this edge.
- * @param distributionPattern The pattern that defines how the connection behaves in parallel.
- */
- public JobEdge(
- IntermediateDataSetID sourceId,
- JobVertex target,
- DistributionPattern distributionPattern) {
- if (sourceId == null || target == null || distributionPattern == null) {
- throw new NullPointerException();
- }
- this.target = target;
- this.distributionPattern = distributionPattern;
- this.sourceId = sourceId;
+ this.isBroadcast = isBroadcast;
}
/**
@@ -140,11 +121,7 @@ public class JobEdge implements java.io.Serializable {
* @return The ID of the consumed data set.
*/
public IntermediateDataSetID getSourceId() {
- return sourceId;
- }
-
- public boolean isIdReference() {
- return this.source == null;
+ return source.getId();
}
// --------------------------------------------------------------------------------------------
@@ -173,11 +150,6 @@ public class JobEdge implements java.io.Serializable {
return isBroadcast;
}
- /** Sets whether the edge is broadcast edge. */
- public void setBroadcast(boolean broadcast) {
- isBroadcast = broadcast;
- }
-
/** Gets whether the edge is forward edge. */
public boolean isForward() {
return isForward;
@@ -268,6 +240,6 @@ public class JobEdge implements java.io.Serializable {
@Override
public String toString() {
- return String.format("%s --> %s [%s]", sourceId, target, distributionPattern.name());
+ return String.format("%s --> %s [%s]", source.getId(), target, distributionPattern.name());
}
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
index e8baef4d8e7..bb821adda7b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobGraph.java
@@ -458,10 +458,9 @@ public class JobGraph implements Serializable {
private void addNodesThatHaveNoNewPredecessors(
JobVertex start, List<JobVertex> target, Set<JobVertex> remaining) {
- // forward traverse over all produced data sets
+ // forward traverse over all produced data sets and all their consumers
for (IntermediateDataSet dataSet : start.getProducedDataSets()) {
- JobEdge edge = dataSet.getConsumer();
- if (edge != null) {
+ for (JobEdge edge : dataSet.getConsumers()) {
// a vertex can be added, if it has no predecessors that are still in the
// 'remaining' set
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
index 36d1a1e5487..717bd6d4376 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/jobgraph/JobVertex.java
@@ -36,7 +36,9 @@ import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -72,8 +74,8 @@ public class JobVertex implements java.io.Serializable {
*/
private final List<OperatorIDPair> operatorIDs;
- /** List of produced data sets, one per writer. */
- private final ArrayList<IntermediateDataSet> results = new ArrayList<>();
+ /** Produced data sets keyed by their ID, in insertion order; one per writer. */
+ private final Map<IntermediateDataSetID, IntermediateDataSet> results = new LinkedHashMap<>();
/** List of edges with incoming data. One per Reader. */
private final ArrayList<JobEdge> inputs = new ArrayList<>();
@@ -374,7 +376,7 @@ public class JobVertex implements java.io.Serializable {
}
public List<IntermediateDataSet> getProducedDataSets() {
- return this.results;
+ return new ArrayList<>(results.values());
}
public List<JobEdge> getInputs() {
@@ -481,30 +483,37 @@ public class JobVertex implements java.io.Serializable {
}
// --------------------------------------------------------------------------------------------
- public IntermediateDataSet createAndAddResultDataSet(
+ public IntermediateDataSet getOrCreateResultDataSet(
IntermediateDataSetID id, ResultPartitionType partitionType) {
-
- IntermediateDataSet result = new IntermediateDataSet(id, partitionType, this);
- this.results.add(result);
- return result;
+ return this.results.computeIfAbsent(
+ id, key -> new IntermediateDataSet(id, partitionType, this));
}
public JobEdge connectNewDataSetAsInput(
JobVertex input, DistributionPattern distPattern, ResultPartitionType partitionType) {
+ return connectNewDataSetAsInput(input, distPattern, partitionType, false);
+ }
+
+ public JobEdge connectNewDataSetAsInput(
+ JobVertex input,
+ DistributionPattern distPattern,
+ ResultPartitionType partitionType,
+ boolean isBroadcast) {
return connectNewDataSetAsInput(
- input, distPattern, partitionType, new IntermediateDataSetID());
+ input, distPattern, partitionType, new IntermediateDataSetID(), isBroadcast);
}
public JobEdge connectNewDataSetAsInput(
JobVertex input,
DistributionPattern distPattern,
ResultPartitionType partitionType,
- IntermediateDataSetID intermediateDataSetId) {
+ IntermediateDataSetID intermediateDataSetId,
+ boolean isBroadcast) {
IntermediateDataSet dataSet =
- input.createAndAddResultDataSet(intermediateDataSetId, partitionType);
+ input.getOrCreateResultDataSet(intermediateDataSetId, partitionType);
- JobEdge edge = new JobEdge(dataSet, this, distPattern);
+ JobEdge edge = new JobEdge(dataSet, this, distPattern, isBroadcast);
this.inputs.add(edge);
dataSet.addConsumer(edge);
return edge;
@@ -525,13 +534,7 @@ public class JobVertex implements java.io.Serializable {
}
public boolean hasNoConnectedInputs() {
- for (JobEdge edge : inputs) {
- if (!edge.isIdReference()) {
- return false;
- }
- }
-
- return true;
+ return inputs.isEmpty();
}
public void markContainsSources() {
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index 0fac885dd3f..1506152fa38 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -45,7 +45,6 @@ import java.util.List;
import java.util.Map;
import java.util.function.Function;
-import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;
/**
@@ -143,15 +142,22 @@ public class SsgNetworkMemoryCalculationUtils {
Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets();
- for (int i = 0; i < producedDataSets.size(); i++) {
- IntermediateDataSet producedDataSet = producedDataSets.get(i);
- JobEdge outputEdge = checkNotNull(producedDataSet.getConsumer());
- ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
- int maxNum =
- EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
- ejv.getParallelism(),
- consumerJobVertex.getParallelism(),
- outputEdge.getDistributionPattern());
+ checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph.");
+ for (IntermediateDataSet producedDataSet : producedDataSets) {
+ int maxNum = 0;
+ List<JobEdge> outputEdges = producedDataSet.getConsumers();
+
+ if (!outputEdges.isEmpty()) {
+ // for non-dynamic graph, the consumer vertices' parallelisms and distribution
+ // patterns must be the same, so the first edge is representative of all consumers
+ JobEdge outputEdge = outputEdges.get(0);
+ ExecutionJobVertex consumerJobVertex = ejvs.apply(outputEdge.getTarget().getID());
+ maxNum =
+ EdgeManagerBuildUtil.computeMaxEdgesToTargetExecutionVertex(
+ ejv.getParallelism(),
+ consumerJobVertex.getParallelism(),
+ outputEdge.getDistributionPattern());
+ }
ret.put(producedDataSet.getId(), maxNum);
}
@@ -177,7 +183,9 @@ public class SsgNetworkMemoryCalculationUtils {
ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst()));
SubpartitionIndexRange subpartitionIndexRange =
TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
- resultPartition, vertex.getParallelSubtaskIndex());
+ partitionGroup.getNumConsumers(),
+ resultPartition,
+ vertex.getParallelSubtaskIndex());
ret.merge(
partitionGroup.getIntermediateDataSetID(),
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
index 486ae28b264..f2fd573e613 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopology.java
@@ -262,7 +262,7 @@ public class DefaultExecutionTopology implements SchedulingTopology {
List<DefaultResultPartition> producedPartitions =
generateProducedSchedulingResultPartition(
vertex.getProducedPartitions(),
- edgeManager::getConsumerVertexGroupForPartition);
+ edgeManager::getConsumerVertexGroupsForPartition);
producedPartitions.forEach(
partition -> resultPartitionsById.put(partition.getId(), partition));
@@ -285,8 +285,8 @@ public class DefaultExecutionTopology implements SchedulingTopology {
private static List<DefaultResultPartition> generateProducedSchedulingResultPartition(
Map<IntermediateResultPartitionID, IntermediateResultPartition>
producedIntermediatePartitions,
- Function<IntermediateResultPartitionID, ConsumerVertexGroup>
- partitionConsumerVertexGroupRetriever) {
+ Function<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
+ partitionConsumerVertexGroupsRetriever) {
List<DefaultResultPartition> producedSchedulingPartitions =
new ArrayList<>(producedIntermediatePartitions.size());
@@ -305,8 +305,8 @@ public class DefaultExecutionTopology implements SchedulingTopology {
? ResultPartitionState.CONSUMABLE
: ResultPartitionState.CREATED,
() ->
- partitionConsumerVertexGroupRetriever.apply(
- irp.getPartitionId()),
+ partitionConsumerVertexGroupsRetriever
+ .apply(irp.getPartitionId()),
irp::getConsumedPartitionGroups)));
return producedSchedulingPartitions;
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
index b46c54bd58e..2ae54af4304 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartition.java
@@ -27,7 +27,6 @@ import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
import org.apache.flink.runtime.scheduler.strategy.SchedulingResultPartition;
import java.util.List;
-import java.util.Optional;
import java.util.function.Supplier;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -45,7 +44,7 @@ class DefaultResultPartition implements SchedulingResultPartition {
private DefaultExecutionVertex producer;
- private final Supplier<ConsumerVertexGroup> consumerVertexGroupSupplier;
+ private final Supplier<List<ConsumerVertexGroup>> consumerVertexGroupsSupplier;
private final Supplier<List<ConsumedPartitionGroup>> consumerPartitionGroupSupplier;
@@ -54,13 +53,13 @@ class DefaultResultPartition implements SchedulingResultPartition {
IntermediateDataSetID intermediateDataSetId,
ResultPartitionType partitionType,
Supplier<ResultPartitionState> resultPartitionStateSupplier,
- Supplier<ConsumerVertexGroup> consumerVertexGroupSupplier,
+ Supplier<List<ConsumerVertexGroup>> consumerVertexGroupsSupplier,
Supplier<List<ConsumedPartitionGroup>> consumerPartitionGroupSupplier) {
this.resultPartitionId = checkNotNull(partitionId);
this.intermediateDataSetId = checkNotNull(intermediateDataSetId);
this.partitionType = checkNotNull(partitionType);
this.resultPartitionStateSupplier = checkNotNull(resultPartitionStateSupplier);
- this.consumerVertexGroupSupplier = checkNotNull(consumerVertexGroupSupplier);
+ this.consumerVertexGroupsSupplier = checkNotNull(consumerVertexGroupsSupplier);
this.consumerPartitionGroupSupplier = checkNotNull(consumerPartitionGroupSupplier);
}
@@ -90,8 +89,8 @@ class DefaultResultPartition implements SchedulingResultPartition {
}
@Override
- public Optional<ConsumerVertexGroup> getConsumerVertexGroup() {
- return Optional.ofNullable(consumerVertexGroupSupplier.get());
+ public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+ return checkNotNull(consumerVertexGroupsSupplier.get());
}
@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
index 6e4672c48e9..2f5be93ca51 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/ConsumedPartitionGroup.java
@@ -42,12 +42,17 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
private final ResultPartitionType resultPartitionType;
+ /** Number of consumer tasks in the corresponding {@link ConsumerVertexGroup}. */
+ private final int numConsumers;
+
private ConsumedPartitionGroup(
+ int numConsumers,
List<IntermediateResultPartitionID> resultPartitions,
ResultPartitionType resultPartitionType) {
checkArgument(
resultPartitions.size() > 0,
"The size of result partitions in the ConsumedPartitionGroup should be larger than 0.");
+ this.numConsumers = numConsumers;
this.intermediateDataSetID = resultPartitions.get(0).getIntermediateDataSetID();
this.resultPartitionType = Preconditions.checkNotNull(resultPartitionType);
@@ -63,16 +68,18 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
}
public static ConsumedPartitionGroup fromMultiplePartitions(
+ int numConsumers,
List<IntermediateResultPartitionID> resultPartitions,
ResultPartitionType resultPartitionType) {
- return new ConsumedPartitionGroup(resultPartitions, resultPartitionType);
+ return new ConsumedPartitionGroup(numConsumers, resultPartitions, resultPartitionType);
}
public static ConsumedPartitionGroup fromSinglePartition(
+ int numConsumers,
IntermediateResultPartitionID resultPartition,
ResultPartitionType resultPartitionType) {
return new ConsumedPartitionGroup(
- Collections.singletonList(resultPartition), resultPartitionType);
+ numConsumers, Collections.singletonList(resultPartition), resultPartitionType);
}
@Override
@@ -88,6 +95,14 @@ public class ConsumedPartitionGroup implements Iterable<IntermediateResultPartit
return resultPartitions.isEmpty();
}
+ /**
+ * Gets the number of consumer tasks in the corresponding {@link ConsumerVertexGroup}. In
+ * dynamic graphs it may differ among groups containing the same IntermediateResultPartition.
+ */
+ public int getNumConsumers() {
+ return numConsumers;
+ }
+
public IntermediateResultPartitionID getFirst() {
return iterator().next();
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
index b52150c96cd..86cafebac6b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/SchedulingResultPartition.java
@@ -24,7 +24,6 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.topology.Result;
import java.util.List;
-import java.util.Optional;
/** Representation of {@link IntermediateResultPartition}. */
public interface SchedulingResultPartition
@@ -49,11 +48,11 @@ public interface SchedulingResultPartition
ResultPartitionState getState();
/**
- * Gets the {@link ConsumerVertexGroup}.
+ * Gets the {@link ConsumerVertexGroup}s.
*
- * @return {@link ConsumerVertexGroup} if consumers exists, otherwise {@link Optional#empty()}.
+ * @return list of {@link ConsumerVertexGroup}s
*/
- Optional<ConsumerVertexGroup> getConsumerVertexGroup();
+ List<ConsumerVertexGroup> getConsumerVertexGroups();
/**
* Gets the {@link ConsumedPartitionGroup}s this partition belongs to.
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
index 4f7d01e4e8f..4d217a9eb97 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/strategy/VertexwiseSchedulingStrategy.java
@@ -24,12 +24,12 @@ import org.apache.flink.runtime.scheduler.SchedulerOperations;
import org.apache.flink.runtime.scheduler.SchedulingTopologyListener;
import org.apache.flink.util.IterableUtils;
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@@ -85,11 +85,9 @@ public class VertexwiseSchedulingStrategy
Set<ExecutionVertexID> consumerVertices =
IterableUtils.toStream(executionVertex.getProducedResults())
- .map(SchedulingResultPartition::getConsumerVertexGroup)
- .filter(Optional::isPresent)
- .flatMap(
- consumerVertexGroup ->
- IterableUtils.toStream(consumerVertexGroup.get()))
+ .map(SchedulingResultPartition::getConsumerVertexGroups)
+ .flatMap(Collection::stream)
+ .flatMap(IterableUtils::toStream)
.collect(Collectors.toSet());
maybeScheduleVertices(consumerVertices);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java
new file mode 100644
index 00000000000..af1c543c509
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/BlockingResultPartitionReleaseTest.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.executiongraph;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.blob.TestingBlobWriter;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
+import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
+import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
+import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.scheduler.SchedulerBase;
+import org.apache.flink.testutils.TestingUtils;
+import org.apache.flink.testutils.executor.TestExecutorExtension;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.RegisterExtension;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ScheduledExecutorService;
+
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
+import static org.apache.flink.runtime.scheduler.SchedulerTestingUtils.createSchedulerAndDeploy;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests that blocking result partitions are released only after all their consumers finish. */
+class BlockingResultPartitionReleaseTest {
+
+ @RegisterExtension
+ static final TestExecutorExtension<ScheduledExecutorService> EXECUTOR_RESOURCE =
+ TestingUtils.defaultExecutorExtension();
+
+ private ScheduledExecutorService scheduledExecutorService;
+ private ComponentMainThreadExecutor mainThreadExecutor;
+ private ManuallyTriggeredScheduledExecutorService ioExecutor;
+
+ @BeforeEach
+ void setup() {
+ scheduledExecutorService = Executors.newSingleThreadScheduledExecutor();
+ mainThreadExecutor =
+ ComponentMainThreadExecutorServiceAdapter.forSingleThreadExecutor(
+ scheduledExecutorService);
+ ioExecutor = new ManuallyTriggeredScheduledExecutorService();
+ }
+
+ @AfterEach
+ void teardown() {
+ if (scheduledExecutorService != null) {
+ scheduledExecutorService.shutdownNow();
+ }
+ }
+
+ @Test
+ void testMultipleConsumersForAdaptiveBatchScheduler() throws Exception {
+ testResultPartitionConsumedByMultiConsumers(true);
+ }
+
+ @Test
+ void testMultipleConsumersForDefaultScheduler() throws Exception {
+ testResultPartitionConsumedByMultiConsumers(false);
+ }
+
+ private void testResultPartitionConsumedByMultiConsumers(boolean isAdaptive) throws Exception {
+ int parallelism = 2;
+ JobID jobId = new JobID();
+ JobVertex producer = ExecutionGraphTestUtils.createNoOpVertex("producer", parallelism);
+ JobVertex consumer1 = ExecutionGraphTestUtils.createNoOpVertex("consumer1", parallelism);
+ JobVertex consumer2 = ExecutionGraphTestUtils.createNoOpVertex("consumer2", parallelism);
+
+ TestingPartitionTracker partitionTracker = new TestingPartitionTracker();
+ SchedulerBase scheduler =
+ createSchedulerAndDeploy(
+ isAdaptive,
+ jobId,
+ producer,
+ new JobVertex[] {consumer1, consumer2},
+ DistributionPattern.ALL_TO_ALL,
+ new TestingBlobWriter(Integer.MAX_VALUE),
+ mainThreadExecutor,
+ ioExecutor,
+ partitionTracker,
+ EXECUTOR_RESOURCE.getExecutor());
+ ExecutionGraph executionGraph = scheduler.getExecutionGraph();
+
+ assertThat(partitionTracker.releasedPartitions).isEmpty();
+
+ CompletableFuture.runAsync(
+ () -> finishJobVertex(executionGraph, consumer1.getID()),
+ mainThreadExecutor)
+ .join();
+ ioExecutor.triggerAll();
+
+ assertThat(partitionTracker.releasedPartitions).isEmpty();
+
+ CompletableFuture.runAsync(
+ () -> finishJobVertex(executionGraph, consumer2.getID()),
+ mainThreadExecutor)
+ .join();
+ ioExecutor.triggerAll();
+
+ assertThat(partitionTracker.releasedPartitions.size()).isEqualTo(parallelism);
+ for (int i = 0; i < parallelism; ++i) {
+ ExecutionJobVertex ejv = checkNotNull(executionGraph.getJobVertex(producer.getID()));
+ assertThat(
+ partitionTracker.releasedPartitions.stream()
+ .map(ResultPartitionID::getPartitionId))
+ .containsExactlyInAnyOrder(
+ Arrays.stream(ejv.getProducedDataSets()[0].getPartitions())
+ .map(IntermediateResultPartition::getPartitionId)
+ .toArray(IntermediateResultPartitionID[]::new));
+ }
+ }
+
+ private static class TestingPartitionTracker extends NoOpJobMasterPartitionTracker {
+
+ private final List<ResultPartitionID> releasedPartitions = new ArrayList<>();
+
+ @Override
+ public void stopTrackingAndReleasePartitions(
+ Collection<ResultPartitionID> resultPartitionIds, boolean releaseOnShuffleMaster) {
+ releasedPartitions.addAll(checkNotNull(resultPartitionIds));
+ }
+ }
+}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
index 9245fa8f966..588a91c6fd7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraphConstructionTest.java
@@ -24,6 +24,7 @@ import org.apache.flink.core.io.InputSplitSource;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -48,6 +49,7 @@ import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
+import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -261,6 +263,45 @@ class DefaultExecutionGraphConstructionTest {
assertThat(eg.getAllVertices().get(v5.getID()).getSplitAssigner()).isEqualTo(assigner2);
}
+ @Test
+ void testMultiConsumersForOneIntermediateResult() throws Exception {
+ // Two consumer vertices are attached to the very same data set of a single producer.
+ JobVertex producer = new JobVertex("vertex1");
+ JobVertex firstConsumer = new JobVertex("vertex2");
+ JobVertex secondConsumer = new JobVertex("vertex3");
+ List<JobVertex> consumers = Arrays.asList(firstConsumer, secondConsumer);
+
+ IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+ for (JobVertex consumer : consumers) {
+ consumer.connectNewDataSetAsInput(
+ producer,
+ DistributionPattern.ALL_TO_ALL,
+ ResultPartitionType.BLOCKING,
+ dataSetId,
+ false);
+ }
+
+ List<JobVertex> vertices =
+ new ArrayList<>(Arrays.asList(producer, firstConsumer, secondConsumer));
+ ExecutionGraph eg = createDefaultExecutionGraph(vertices);
+ eg.attachJobGraph(vertices);
+
+ // The producer exposes exactly one result, identified by the shared data set ID.
+ ExecutionJobVertex producerEjv = checkNotNull(eg.getJobVertex(producer.getID()));
+ assertThat(producerEjv.getProducedDataSets()).hasSize(1);
+ assertThat(producerEjv.getProducedDataSets()[0].getId()).isEqualTo(dataSetId);
+
+ // Every consumer has that result as its only input ...
+ for (JobVertex consumer : consumers) {
+ ExecutionJobVertex consumerEjv = checkNotNull(eg.getJobVertex(consumer.getID()));
+ assertThat(consumerEjv.getInputs()).hasSize(1);
+ assertThat(consumerEjv.getInputs().get(0).getId()).isEqualTo(dataSetId);
+ }
+
+ // ... and the first subtask of every consumer reads one group of that data set.
+ for (JobVertex consumer : consumers) {
+ ExecutionJobVertex consumerEjv = checkNotNull(eg.getJobVertex(consumer.getID()));
+ List<ConsumedPartitionGroup> partitionGroups =
+ consumerEjv.getTaskVertices()[0].getAllConsumedPartitionGroups();
+ assertThat(partitionGroups).hasSize(1);
+ assertThat(partitionGroups.get(0).getIntermediateDataSetID()).isEqualTo(dataSetId);
+ }
+ }
+
@Test
void testRegisterConsumedPartitionGroupToEdgeManager() throws Exception {
JobVertex v1 = new JobVertex("source");
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
index e3b603a8790..28b44595a26 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerBuildUtilTest.java
@@ -85,7 +85,7 @@ class EdgeManagerBuildUtilTest {
IntermediateResultPartition partition =
ev.getProducedPartitions().values().iterator().next();
- ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroup();
+ ConsumerVertexGroup consumerVertexGroup = partition.getConsumerVertexGroups().get(0);
int actual = consumerVertexGroup.size();
if (actual > actualMaxForUpstream) {
actualMaxForUpstream = actual;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
index 81165da6193..ad94749e7e8 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/EdgeManagerTest.java
@@ -34,9 +34,15 @@ import org.apache.flink.testutils.executor.TestExecutorExtension;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
+import java.util.Arrays;
+import java.util.List;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
+import java.util.stream.Collectors;
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.ALL_TO_ALL;
+import static org.apache.flink.runtime.jobgraph.DistributionPattern.POINTWISE;
+import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.assertj.core.api.Assertions.assertThat;
/** Tests for {@link EdgeManager}. */
@@ -50,23 +56,7 @@ class EdgeManagerTest {
void testGetConsumedPartitionGroup() throws Exception {
JobVertex v1 = new JobVertex("source");
JobVertex v2 = new JobVertex("sink");
-
- v1.setParallelism(2);
- v2.setParallelism(2);
-
- v1.setInvokableClass(NoOpInvokable.class);
- v2.setInvokableClass(NoOpInvokable.class);
-
- v2.connectNewDataSetAsInput(
- v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
-
- JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(v1, v2);
- SchedulerBase scheduler =
- SchedulerTestingUtils.createScheduler(
- jobGraph,
- ComponentMainThreadExecutorServiceAdapter.forMainThread(),
- EXECUTOR_RESOURCE.getExecutor());
- ExecutionGraph eg = scheduler.getExecutionGraph();
+ ExecutionGraph eg = buildExecutionGraph(v1, v2, 2, 2, ALL_TO_ALL);
ConsumedPartitionGroup groupRetrievedByDownstreamVertex =
Objects.requireNonNull(eg.getJobVertex(v2.getID()))
@@ -86,9 +76,7 @@ class EdgeManagerTest {
.isEqualTo(groupRetrievedByDownstreamVertex);
ConsumedPartitionGroup groupRetrievedByScheduledResultPartition =
- scheduler
- .getExecutionGraph()
- .getSchedulingTopology()
+ eg.getSchedulingTopology()
.getResultPartition(consumedPartition.getPartitionId())
.getConsumedPartitionGroups()
.get(0);
@@ -96,4 +84,64 @@ class EdgeManagerTest {
assertThat(groupRetrievedByScheduledResultPartition)
.isEqualTo(groupRetrievedByDownstreamVertex);
}
+
+ /**
+ * Checks {@code ConsumedPartitionGroup#getNumConsumers()} for several producer/consumer
+ * parallelism combinations under both distribution patterns. Each expected array holds the
+ * value expected for the consumed partition groups of the consumer subtasks, in subtask order.
+ */
+ @Test
+ void testCalculateNumberOfConsumers() throws Exception {
+ // Package-private like the other JUnit 5 tests of this class.
+ testCalculateNumberOfConsumers(5, 2, ALL_TO_ALL, new int[] {2, 2});
+ testCalculateNumberOfConsumers(5, 2, POINTWISE, new int[] {1, 1});
+ testCalculateNumberOfConsumers(2, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
+ testCalculateNumberOfConsumers(2, 5, POINTWISE, new int[] {3, 3, 3, 2, 2});
+ testCalculateNumberOfConsumers(5, 5, ALL_TO_ALL, new int[] {5, 5, 5, 5, 5});
+ testCalculateNumberOfConsumers(5, 5, POINTWISE, new int[] {1, 1, 1, 1, 1});
+ }
+
+ /**
+ * Builds a two-vertex BLOCKING job with the given parallelisms and distribution pattern, then
+ * verifies the number of consumers reported by each consumed partition group of the consumer.
+ *
+ * @param producerParallelism parallelism of the producer vertex
+ * @param consumerParallelism parallelism of the consumer vertex
+ * @param distributionPattern pattern used to connect consumer to producer
+ * @param expectedConsumers expected {@code getNumConsumers()} per consumed partition group,
+ * in the order the groups are collected from the consumer's task vertices
+ */
+ private void testCalculateNumberOfConsumers(
+ int producerParallelism,
+ int consumerParallelism,
+ DistributionPattern distributionPattern,
+ int[] expectedConsumers)
+ throws Exception {
+ JobVertex producer = new JobVertex("producer");
+ JobVertex consumer = new JobVertex("consumer");
+ ExecutionGraph eg =
+ buildExecutionGraph(
+ producer,
+ consumer,
+ producerParallelism,
+ consumerParallelism,
+ distributionPattern);
+ List<ConsumedPartitionGroup> partitionGroups =
+ Arrays.stream(checkNotNull(eg.getJobVertex(consumer.getID())).getTaskVertices())
+ .flatMap(ev -> ev.getAllConsumedPartitionGroups().stream())
+ .collect(Collectors.toList());
+ // Without this check an empty or too-short group list would pass the loop vacuously.
+ assertThat(partitionGroups).hasSize(expectedConsumers.length);
+ for (int i = 0; i < expectedConsumers.length; i++) {
+ assertThat(partitionGroups.get(i).getNumConsumers()).isEqualTo(expectedConsumers[i]);
+ }
+ }
+
+ /**
+ * Configures the two vertices as no-op tasks with the given parallelisms, connects them with a
+ * BLOCKING edge of the given pattern and returns the execution graph of a scheduler created
+ * for the resulting batch job graph.
+ */
+ private ExecutionGraph buildExecutionGraph(
+ JobVertex producer,
+ JobVertex consumer,
+ int producerParallelism,
+ int consumerParallelism,
+ DistributionPattern distributionPattern)
+ throws Exception {
+ producer.setInvokableClass(NoOpInvokable.class);
+ producer.setParallelism(producerParallelism);
+
+ consumer.setInvokableClass(NoOpInvokable.class);
+ consumer.setParallelism(consumerParallelism);
+
+ consumer.connectNewDataSetAsInput(
+ producer, distributionPattern, ResultPartitionType.BLOCKING);
+
+ return SchedulerTestingUtils.createScheduler(
+ JobGraphTestUtils.batchJobGraph(producer, consumer),
+ ComponentMainThreadExecutorServiceAdapter.forMainThread(),
+ EXECUTOR_RESOURCE.getExecutor())
+ .getExecutionGraph();
+ }
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
index ff2baf807e1..a9e7df9470d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionGraphTestUtils.java
@@ -35,6 +35,7 @@ import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
+import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.apache.flink.runtime.testutils.DirectScheduledExecutorService;
@@ -43,6 +44,7 @@ import javax.annotation.Nullable;
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.List;
+import java.util.Objects;
import java.util.Random;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
@@ -249,6 +251,23 @@ public class ExecutionGraphTestUtils {
}
}
+ /**
+ * Reports FINISHED for the current execution attempt of every subtask of the given job vertex.
+ * Fails with a {@link NullPointerException} if the graph does not know the vertex.
+ */
+ public static void finishJobVertex(ExecutionGraph executionGraph, JobVertexID jobVertexId) {
+ final ExecutionJobVertex jobVertex =
+ Objects.requireNonNull(executionGraph.getJobVertex(jobVertexId));
+ for (ExecutionVertex subtask : jobVertex.getTaskVertices()) {
+ finishExecutionVertex(executionGraph, subtask);
+ }
+ }
+
+ /**
+ * Sends a FINISHED state update for the current execution attempt of the given vertex to the
+ * execution graph.
+ */
+ public static void finishExecutionVertex(
+ ExecutionGraph executionGraph, ExecutionVertex executionVertex) {
+ final ExecutionAttemptID attemptId =
+ executionVertex.getCurrentExecutionAttempt().getAttemptId();
+ final TaskExecutionState finishedState =
+ new TaskExecutionState(attemptId, ExecutionState.FINISHED);
+ executionGraph.updateState(new TaskExecutionStateTransition(finishedState));
+ }
+
/**
* Takes all vertices in the given ExecutionGraph and switches their current execution to
* FINISHED.
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index d847e9ede68..dda16596751 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -170,7 +170,7 @@ class ExecutionJobVertexTest {
int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception {
JobVertex jobVertex = new JobVertex("testVertex");
jobVertex.setInvokableClass(AbstractInvokable.class);
- jobVertex.createAndAddResultDataSet(
+ jobVertex.getOrCreateResultDataSet(
new IntermediateDataSetID(), ResultPartitionType.BLOCKING);
if (maxParallelism > 0) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 5ff5d9f201c..aa906caf674 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAdapter;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
@@ -151,6 +152,32 @@ public class IntermediateResultPartitionTest {
assertThat(consumedPartitionGroup.areAllPartitionsFinished()).isFalse();
}
+ @Test
+ void testReleasePartitionGroups() throws Exception {
+ IntermediateResult result = createResult(ResultPartitionType.BLOCKING, 2);
+
+ IntermediateResultPartition first = result.getPartitions()[0];
+ IntermediateResultPartition second = result.getPartitions()[1];
+ // Nothing has been marked releasable yet.
+ assertThat(first.canBeReleased()).isFalse();
+ assertThat(second.canBeReleased()).isFalse();
+
+ // Both partitions are expected to report the same list of consumed partition groups.
+ List<ConsumedPartitionGroup> groupsOfFirst = first.getConsumedPartitionGroups();
+ List<ConsumedPartitionGroup> groupsOfSecond = second.getConsumedPartitionGroups();
+ assertThat(groupsOfFirst).isEqualTo(groupsOfSecond);
+
+ // Marking only one of the two groups must not make the partition releasable.
+ assertThat(groupsOfFirst).hasSize(2);
+ first.markPartitionGroupReleasable(groupsOfFirst.get(0));
+ assertThat(first.canBeReleased()).isFalse();
+
+ // Once the second group is marked as well, the partition can be released.
+ first.markPartitionGroupReleasable(groupsOfFirst.get(1));
+ assertThat(first.canBeReleased()).isTrue();
+
+ // Resetting the result for a new execution clears the releasable state again.
+ result.resetForNewExecution();
+ assertThat(first.canBeReleased()).isFalse();
+ }
+
@Test
void testGetNumberOfSubpartitionsForNonDynamicAllToAllGraph() throws Exception {
testGetNumberOfSubpartitions(7, DistributionPattern.ALL_TO_ALL, false, Arrays.asList(7, 7));
@@ -245,11 +272,24 @@ public class IntermediateResultPartitionTest {
v2.setMaxParallelism(consumerMaxParallelism);
}
- v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING);
+ final JobVertex v3 = new JobVertex("v3");
+ v3.setInvokableClass(NoOpInvokable.class);
+ if (consumerParallelism > 0) {
+ v3.setParallelism(consumerParallelism);
+ }
+ if (consumerMaxParallelism > 0) {
+ v3.setMaxParallelism(consumerMaxParallelism);
+ }
+
+ IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+ v2.connectNewDataSetAsInput(
+ v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
+ v3.connectNewDataSetAsInput(
+ v1, distributionPattern, ResultPartitionType.BLOCKING, dataSetId, false);
final JobGraph jobGraph =
JobGraphBuilder.newBatchJobGraphBuilder()
- .addJobVertices(Arrays.asList(v1, v2))
+ .addJobVertices(Arrays.asList(v1, v2, v3))
.build();
final Configuration configuration = new Configuration();
@@ -287,15 +327,23 @@ public class IntermediateResultPartitionTest {
source.setInvokableClass(NoOpInvokable.class);
source.setParallelism(parallelism);
- JobVertex sink = new JobVertex("v2");
- sink.setInvokableClass(NoOpInvokable.class);
- sink.setParallelism(parallelism);
+ JobVertex sink1 = new JobVertex("v2");
+ sink1.setInvokableClass(NoOpInvokable.class);
+ sink1.setParallelism(parallelism);
+
+ JobVertex sink2 = new JobVertex("v3");
+ sink2.setInvokableClass(NoOpInvokable.class);
+ sink2.setParallelism(parallelism);
- sink.connectNewDataSetAsInput(source, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+ IntermediateDataSetID dataSetId = new IntermediateDataSetID();
+ sink1.connectNewDataSetAsInput(
+ source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
+ sink2.connectNewDataSetAsInput(
+ source, DistributionPattern.ALL_TO_ALL, resultPartitionType, dataSetId, false);
ScheduledExecutorService executorService = new DirectScheduledExecutorService();
- JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink);
+ JobGraph jobGraph = JobGraphTestUtils.batchJobGraph(source, sink1, sink2);
SchedulerBase scheduler =
new DefaultSchedulerBuilder(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
index 7c00da22a85..4bf976c76b0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/RemoveCachedShuffleDescriptorTest.java
@@ -19,7 +19,6 @@
package org.apache.flink.runtime.executiongraph;
import org.apache.flink.api.common.JobID;
-import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.blob.BlobWriter;
import org.apache.flink.runtime.blob.TestingBlobWriter;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
@@ -27,21 +26,14 @@ import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutorServiceAda
import org.apache.flink.runtime.concurrent.ManuallyTriggeredScheduledExecutorService;
import org.apache.flink.runtime.deployment.TaskDeploymentDescriptor.MaybeOffloaded;
import org.apache.flink.runtime.execution.ExecutionState;
-import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.NoOpJobMasterPartitionTracker;
import org.apache.flink.runtime.io.network.partition.PartitionNotFoundException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
-import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
-import org.apache.flink.runtime.jobgraph.JobGraph;
-import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
import org.apache.flink.runtime.jobgraph.JobVertex;
-import org.apache.flink.runtime.jobgraph.JobVertexID;
-import org.apache.flink.runtime.jobmaster.LogicalSlot;
-import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
-import org.apache.flink.runtime.scheduler.DefaultScheduler;
-import org.apache.flink.runtime.scheduler.DefaultSchedulerBuilder;
+import org.apache.flink.runtime.scheduler.SchedulerBase;
+import org.apache.flink.runtime.scheduler.SchedulerTestingUtils;
import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
-import org.apache.flink.runtime.taskmanager.TaskExecutionState;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorExtension;
@@ -50,17 +42,16 @@ import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
-import java.util.ArrayList;
-import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeoutException;
import static org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactoryTest.deserializeShuffleDescriptors;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishExecutionVertex;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
import static org.assertj.core.api.Assertions.assertThat;
/**
@@ -119,7 +110,7 @@ class RemoveCachedShuffleDescriptorTest {
final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
- final DefaultScheduler scheduler =
+ final SchedulerBase scheduler =
createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
@@ -132,8 +123,7 @@ class RemoveCachedShuffleDescriptorTest {
// For the all-to-all edge, we transition all downstream tasks to finished
CompletableFuture.runAsync(
- () -> transitionTasksToFinished(executionGraph, v2.getID()),
- mainThreadExecutor)
+ () -> finishJobVertex(executionGraph, v2.getID()), mainThreadExecutor)
.join();
ioExecutor.triggerAll();
@@ -164,7 +154,7 @@ class RemoveCachedShuffleDescriptorTest {
final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
- final DefaultScheduler scheduler =
+ final SchedulerBase scheduler =
createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.ALL_TO_ALL, blobWriter);
final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
@@ -206,7 +196,7 @@ class RemoveCachedShuffleDescriptorTest {
final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
- final DefaultScheduler scheduler =
+ final SchedulerBase scheduler =
createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
@@ -222,7 +212,7 @@ class RemoveCachedShuffleDescriptorTest {
Objects.requireNonNull(executionGraph.getJobVertex(v2.getID()))
.getTaskVertices()[0];
CompletableFuture.runAsync(
- () -> transitionTaskToFinished(executionGraph, ev21), mainThreadExecutor)
+ () -> finishExecutionVertex(executionGraph, ev21), mainThreadExecutor)
.join();
ioExecutor.triggerAll();
@@ -263,7 +253,7 @@ class RemoveCachedShuffleDescriptorTest {
final JobVertex v1 = ExecutionGraphTestUtils.createNoOpVertex("v1", PARALLELISM);
final JobVertex v2 = ExecutionGraphTestUtils.createNoOpVertex("v2", PARALLELISM);
- final DefaultScheduler scheduler =
+ final SchedulerBase scheduler =
createSchedulerAndDeploy(jobId, v1, v2, DistributionPattern.POINTWISE, blobWriter);
final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
@@ -292,42 +282,27 @@ class RemoveCachedShuffleDescriptorTest {
assertThat(blobWriter.numberOfBlobs()).isEqualTo(expectedAfter);
}
- private DefaultScheduler createSchedulerAndDeploy(
+ private SchedulerBase createSchedulerAndDeploy(
JobID jobId,
JobVertex v1,
JobVertex v2,
DistributionPattern distributionPattern,
BlobWriter blobWriter)
throws Exception {
-
- v2.connectNewDataSetAsInput(v1, distributionPattern, ResultPartitionType.BLOCKING);
-
- final List<JobVertex> ordered = new ArrayList<>(Arrays.asList(v1, v2));
- final DefaultScheduler scheduler =
- createScheduler(jobId, ordered, blobWriter, mainThreadExecutor, ioExecutor);
- final ExecutionGraph executionGraph = scheduler.getExecutionGraph();
- final TestingLogicalSlotBuilder slotBuilder = new TestingLogicalSlotBuilder();
-
- CompletableFuture.runAsync(
- () -> {
- try {
- // Deploy upstream source vertices
- deployTasks(executionGraph, v1.getID(), slotBuilder);
- // Transition upstream vertices into FINISHED
- transitionTasksToFinished(executionGraph, v1.getID());
- // Deploy downstream sink vertices
- deployTasks(executionGraph, v2.getID(), slotBuilder);
- } catch (Exception e) {
- throw new RuntimeException("Exceptions shouldn't happen here.", e);
- }
- },
- mainThreadExecutor)
- .join();
-
- return scheduler;
+ return SchedulerTestingUtils.createSchedulerAndDeploy(
+ false,
+ jobId,
+ v1,
+ new JobVertex[] {v2},
+ distributionPattern,
+ blobWriter,
+ mainThreadExecutor,
+ ioExecutor,
+ NoOpJobMasterPartitionTracker.INSTANCE,
+ EXECUTOR_RESOURCE.getExecutor());
}
- private void triggerGlobalFailoverAndComplete(DefaultScheduler scheduler, JobVertex upstream)
+ private void triggerGlobalFailoverAndComplete(SchedulerBase scheduler, JobVertex upstream)
throws TimeoutException {
final Throwable t = new Exception();
@@ -378,66 +353,6 @@ class RemoveCachedShuffleDescriptorTest {
// ============== Utils ==============
- private static DefaultScheduler createScheduler(
- final JobID jobId,
- final List<JobVertex> jobVertices,
- final BlobWriter blobWriter,
- final ComponentMainThreadExecutor mainThreadExecutor,
- final ScheduledExecutorService ioExecutor)
- throws Exception {
- final JobGraph jobGraph =
- JobGraphBuilder.newBatchJobGraphBuilder()
- .setJobId(jobId)
- .addJobVertices(jobVertices)
- .build();
-
- return new DefaultSchedulerBuilder(
- jobGraph, mainThreadExecutor, EXECUTOR_RESOURCE.getExecutor())
- .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0))
- .setBlobWriter(blobWriter)
- .setIoExecutor(ioExecutor)
- .build();
- }
-
- private static void deployTasks(
- ExecutionGraph executionGraph,
- JobVertexID jobVertexID,
- TestingLogicalSlotBuilder slotBuilder)
- throws JobException, ExecutionException, InterruptedException {
-
- for (ExecutionVertex vertex :
- Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
- .getTaskVertices()) {
- LogicalSlot slot = slotBuilder.createTestingLogicalSlot();
-
- Execution execution = vertex.getCurrentExecutionAttempt();
- execution.registerProducedPartitions(slot.getTaskManagerLocation()).get();
- execution.transitionState(ExecutionState.SCHEDULED);
-
- vertex.tryAssignResource(slot);
- vertex.deploy();
- }
- }
-
- private static void transitionTasksToFinished(
- ExecutionGraph executionGraph, JobVertexID jobVertexID) {
-
- for (ExecutionVertex vertex :
- Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID))
- .getTaskVertices()) {
- transitionTaskToFinished(executionGraph, vertex);
- }
- }
-
- private static void transitionTaskToFinished(
- ExecutionGraph executionGraph, ExecutionVertex executionVertex) {
- executionGraph.updateState(
- new TaskExecutionStateTransition(
- new TaskExecutionState(
- executionVertex.getCurrentExecutionAttempt().getAttemptId(),
- ExecutionState.FINISHED)));
- }
-
private static MaybeOffloaded<ShuffleDescriptor[]> getConsumedCachedShuffleDescriptor(
ExecutionGraph executionGraph, JobVertex vertex) {
return getConsumedCachedShuffleDescriptor(executionGraph, vertex, 0);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
index 087a29613e1..130212bcf05 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpJobMasterPartitionTracker.java
@@ -28,8 +28,9 @@ import java.util.Collections;
import java.util.List;
/** No-op implementation of {@link JobMasterPartitionTracker}. */
-public enum NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker {
- INSTANCE;
+public class NoOpJobMasterPartitionTracker implements JobMasterPartitionTracker {
+ public static final NoOpJobMasterPartitionTracker INSTANCE =
+ new NoOpJobMasterPartitionTracker();
public static final PartitionTrackerFactory FACTORY = lookup -> INSTANCE;
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
index 66e097b34b1..cf59753999f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobgraph/JobTaskVertexTest.java
@@ -34,6 +34,7 @@ import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.net.URL;
import java.net.URLClassLoader;
+import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -41,6 +42,45 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
@SuppressWarnings("serial")
class JobTaskVertexTest {
+ @Test
+ void testMultipleConsumersVertices() {
+ JobVertex producer = new JobVertex("producer");
+ JobVertex consumer1 = new JobVertex("consumer1");
+ JobVertex consumer2 = new JobVertex("consumer2");
+
+ // consumer1 and consumer2 share one data set by passing the same explicit ID.
+ IntermediateDataSetID sharedDataSetId = new IntermediateDataSetID();
+ for (JobVertex sharingConsumer : new JobVertex[] {consumer1, consumer2}) {
+ sharingConsumer.connectNewDataSetAsInput(
+ producer,
+ DistributionPattern.ALL_TO_ALL,
+ ResultPartitionType.BLOCKING,
+ sharedDataSetId,
+ false);
+ }
+
+ // consumer3 connects without an explicit ID and is expected to get its own data set.
+ JobVertex consumer3 = new JobVertex("consumer3");
+ consumer3.connectNewDataSetAsInput(
+ producer, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
+
+ assertThat(producer.getProducedDataSets()).hasSize(2);
+
+ IntermediateDataSet sharedDataSet = producer.getProducedDataSets().get(0);
+ assertThat(sharedDataSet.getId()).isEqualTo(sharedDataSetId);
+
+ // The shared data set lists both consumers, in connection order.
+ List<JobEdge> sharedEdges = sharedDataSet.getConsumers();
+ assertThat(sharedEdges).hasSize(2);
+ assertThat(sharedEdges.get(0).getTarget().getID()).isEqualTo(consumer1.getID());
+ assertThat(sharedEdges.get(1).getTarget().getID()).isEqualTo(consumer2.getID());
+
+ // The second data set is consumed by consumer3 only.
+ List<JobEdge> exclusiveEdges = producer.getProducedDataSets().get(1).getConsumers();
+ assertThat(exclusiveEdges).hasSize(1);
+ assertThat(exclusiveEdges.get(0).getTarget().getID()).isEqualTo(consumer3.getID());
+ }
+
@Test
void testConnectDirectly() {
JobVertex source = new JobVertex("source");
@@ -59,7 +99,8 @@ class JobTaskVertexTest {
assertThat(source.getProducedDataSets().get(0))
.isEqualTo(target.getInputs().get(0).getSource());
- assertThat(source.getProducedDataSets().get(0).getConsumer().getTarget()).isEqualTo(target);
+ assertThat(source.getProducedDataSets().get(0).getConsumers().get(0).getTarget())
+ .isEqualTo(target);
}
@Test
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
index 7f9488fe901..a0ce9132e3f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmaster/JobIntermediateDatasetReuseTest.java
@@ -206,7 +206,8 @@ public class JobIntermediateDatasetReuseTest {
sender,
DistributionPattern.POINTWISE,
ResultPartitionType.BLOCKING_PERSISTENT,
- intermediateDataSetID);
+ intermediateDataSetID,
+ false);
return new JobGraph(null, "First Job", sender, receiver);
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
index eb09a7e139c..b3954d2ea07 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SchedulerTestingUtils.java
@@ -20,6 +20,8 @@ package org.apache.flink.runtime.scheduler;
import org.apache.flink.api.common.JobID;
import org.apache.flink.api.common.time.Time;
+import org.apache.flink.runtime.JobException;
+import org.apache.flink.runtime.blob.BlobWriter;
import org.apache.flink.runtime.checkpoint.CheckpointCoordinator;
import org.apache.flink.runtime.checkpoint.CheckpointException;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
@@ -28,12 +30,24 @@ import org.apache.flink.runtime.checkpoint.CompletedCheckpoint;
import org.apache.flink.runtime.checkpoint.PendingCheckpoint;
import org.apache.flink.runtime.concurrent.ComponentMainThreadExecutor;
import org.apache.flink.runtime.execution.ExecutionState;
+import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.ExecutionGraph;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
+import org.apache.flink.runtime.executiongraph.failover.flip1.TestRestartBackoffTimeStrategy;
+import org.apache.flink.runtime.io.network.partition.JobMasterPartitionTracker;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobGraphBuilder;
+import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
+import org.apache.flink.runtime.jobmaster.LogicalSlot;
+import org.apache.flink.runtime.jobmaster.TestingLogicalSlotBuilder;
import org.apache.flink.runtime.jobmaster.slotpool.PhysicalSlotProvider;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.scheduler.strategy.ExecutionVertexID;
@@ -46,12 +60,18 @@ import org.apache.flink.util.TernaryBoolean;
import javax.annotation.Nullable;
import java.io.IOException;
+import java.util.ArrayList;
import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import java.util.concurrent.ScheduledExecutorService;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;
+import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.finishJobVertex;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
@@ -297,4 +317,121 @@ public class SchedulerTestingUtils {
allocationTimeout,
new LocalInputPreferredSlotSharingStrategy.Factory());
}
+
+    /**
+     * Creates a scheduler for a job in which all {@code consumers} read the same BLOCKING
+     * intermediate dataset produced by {@code producer}, then runs the deployment sequence on
+     * {@code mainThreadExecutor}: the producer tasks are deployed and finished first, the tasks of
+     * each consumer are deployed afterwards.
+     *
+     * <p>The method blocks until the deployment sequence has completed.
+     *
+     * @param isAdaptive whether the adaptive batch job scheduler is built; if so, each job vertex
+     *     is explicitly initialized before its tasks are deployed
+     * @param jobId ID assigned to the created job graph
+     * @param producer vertex producing the shared intermediate dataset
+     * @param consumers vertices consuming the shared intermediate dataset
+     * @param distributionPattern pattern used to connect every consumer to the producer
+     * @param blobWriter blob writer handed to the scheduler builder
+     * @param mainThreadExecutor executor on which the deployment sequence runs
+     * @param ioExecutor I/O executor handed to the scheduler builder
+     * @param partitionTracker partition tracker handed to the scheduler builder
+     * @param scheduledExecutor scheduled executor handed to the scheduler builder
+     * @return the scheduler whose execution graph has been deployed
+     * @throws Exception if the scheduler cannot be created
+     */
+    public static SchedulerBase createSchedulerAndDeploy(
+            boolean isAdaptive,
+            JobID jobId,
+            JobVertex producer,
+            JobVertex[] consumers,
+            DistributionPattern distributionPattern,
+            BlobWriter blobWriter,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            ScheduledExecutorService ioExecutor,
+            JobMasterPartitionTracker partitionTracker,
+            ScheduledExecutorService scheduledExecutor)
+            throws Exception {
+        // A single dataset ID is reused for every edge so that all consumers share one dataset.
+        final IntermediateDataSetID sharedDataSetId = new IntermediateDataSetID();
+        final List<JobVertex> allVertices = new ArrayList<>();
+        allVertices.add(producer);
+        for (JobVertex consumer : consumers) {
+            consumer.connectNewDataSetAsInput(
+                    producer,
+                    distributionPattern,
+                    ResultPartitionType.BLOCKING,
+                    sharedDataSetId,
+                    false);
+            allVertices.add(consumer);
+        }
+
+        final SchedulerBase scheduler =
+                createScheduler(
+                        isAdaptive,
+                        jobId,
+                        allVertices,
+                        blobWriter,
+                        mainThreadExecutor,
+                        ioExecutor,
+                        partitionTracker,
+                        scheduledExecutor);
+        final ExecutionGraph graph = scheduler.getExecutionGraph();
+        final TestingLogicalSlotBuilder slots = new TestingLogicalSlotBuilder();
+
+        final Runnable deployment =
+                () -> {
+                    try {
+                        if (isAdaptive) {
+                            initializeExecutionJobVertex(producer.getID(), graph);
+                        }
+                        // Deploy the upstream tasks and move them to FINISHED before any
+                        // downstream task is deployed.
+                        deployTasks(graph, producer.getID(), slots);
+                        finishJobVertex(graph, producer.getID());
+
+                        for (JobVertex consumer : consumers) {
+                            if (isAdaptive) {
+                                initializeExecutionJobVertex(consumer.getID(), graph);
+                            }
+                            deployTasks(graph, consumer.getID(), slots);
+                        }
+                    } catch (Exception e) {
+                        throw new RuntimeException("Exceptions shouldn't happen here.", e);
+                    }
+                };
+
+        // join() keeps the caller waiting until the whole sequence ran on the main thread.
+        CompletableFuture.runAsync(deployment, mainThreadExecutor).join();
+        return scheduler;
+    }
+
+    /**
+     * Initializes the execution job vertex with the given ID and notifies the execution graph
+     * about the newly initialized vertex.
+     *
+     * @param jobVertex ID of the job vertex to initialize
+     * @param executionGraph execution graph in which the vertex is looked up
+     * @throws NullPointerException if the execution graph returns no vertex for the given ID
+     * @throws RuntimeException wrapping the {@link JobException} raised during initialization
+     */
+    private static void initializeExecutionJobVertex(
+            JobVertexID jobVertex, ExecutionGraph executionGraph) {
+        // Resolve the vertex once and fail fast with a descriptive message instead of passing a
+        // possibly-null vertex on; deployTasks guards the same lookup with requireNonNull.
+        final ExecutionJobVertex executionJobVertex =
+                Objects.requireNonNull(
+                        executionGraph.getJobVertex(jobVertex),
+                        () -> "No execution job vertex found for job vertex " + jobVertex);
+        try {
+            executionGraph.initializeJobVertex(executionJobVertex, System.currentTimeMillis());
+            executionGraph.notifyNewlyInitializedJobVertices(
+                    Collections.singletonList(executionJobVertex));
+        } catch (JobException exception) {
+            throw new RuntimeException(exception);
+        }
+    }
+
+    /**
+     * Builds a batch job graph from the given vertices and creates a scheduler for it.
+     *
+     * @param isAdaptive if {@code true} the adaptive batch job scheduler is built, otherwise the
+     *     builder's default scheduler
+     * @param jobId ID assigned to the job graph
+     * @param jobVertices vertices added to the job graph
+     * @param blobWriter blob writer set on the scheduler builder
+     * @param mainThreadExecutor main thread executor passed to the scheduler builder
+     * @param ioExecutor I/O executor set on the scheduler builder
+     * @param partitionTracker partition tracker set on the scheduler builder
+     * @param scheduledExecutor scheduled executor passed to the scheduler builder
+     * @return the created scheduler
+     * @throws Exception if building the scheduler fails
+     */
+    private static DefaultScheduler createScheduler(
+            boolean isAdaptive,
+            JobID jobId,
+            List<JobVertex> jobVertices,
+            BlobWriter blobWriter,
+            ComponentMainThreadExecutor mainThreadExecutor,
+            ScheduledExecutorService ioExecutor,
+            JobMasterPartitionTracker partitionTracker,
+            ScheduledExecutorService scheduledExecutor)
+            throws Exception {
+        final JobGraph batchJobGraph =
+                JobGraphBuilder.newBatchJobGraphBuilder()
+                        .setJobId(jobId)
+                        .addJobVertices(jobVertices)
+                        .build();
+
+        final DefaultSchedulerBuilder schedulerBuilder =
+                new DefaultSchedulerBuilder(batchJobGraph, mainThreadExecutor, scheduledExecutor)
+                        .setRestartBackoffTimeStrategy(new TestRestartBackoffTimeStrategy(true, 0))
+                        .setBlobWriter(blobWriter)
+                        .setIoExecutor(ioExecutor)
+                        .setPartitionTracker(partitionTracker);
+
+        if (isAdaptive) {
+            return schedulerBuilder.buildAdaptiveBatchJobScheduler();
+        }
+        return schedulerBuilder.build();
+    }
+
+    /**
+     * Deploys every task of the given job vertex. For each execution vertex a new testing slot is
+     * created, the produced partitions of the current execution attempt are registered at the
+     * slot's task manager location, the attempt is moved to {@code SCHEDULED}, and the slot is
+     * assigned before the vertex is deployed.
+     *
+     * @param executionGraph execution graph containing the job vertex
+     * @param jobVertexID ID of the job vertex whose tasks are deployed
+     * @param slotBuilder builder used to create one testing slot per task
+     * @throws NullPointerException if the execution graph returns no vertex for the given ID
+     * @throws JobException declared for the deployment calls made for each task
+     * @throws ExecutionException if waiting for the partition registration fails
+     * @throws InterruptedException if interrupted while waiting for the partition registration
+     */
+    private static void deployTasks(
+            ExecutionGraph executionGraph,
+            JobVertexID jobVertexID,
+            TestingLogicalSlotBuilder slotBuilder)
+            throws JobException, ExecutionException, InterruptedException {
+        final ExecutionJobVertex executionJobVertex =
+                Objects.requireNonNull(executionGraph.getJobVertex(jobVertexID));
+
+        for (ExecutionVertex task : executionJobVertex.getTaskVertices()) {
+            final LogicalSlot testingSlot = slotBuilder.createTestingLogicalSlot();
+            final Execution attempt = task.getCurrentExecutionAttempt();
+
+            // Wait for the partition registration before changing the attempt's state.
+            attempt.registerProducedPartitions(testingSlot.getTaskManagerLocation()).get();
+            attempt.transitionState(ExecutionState.SCHEDULED);
+
+            task.tryAssignResource(testingSlot);
+            task.deploy();
+        }
+    }
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
index 4e1d9cb76a4..f22c3b23dd7 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionTopologyTest.java
@@ -52,7 +52,6 @@ import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
@@ -314,17 +313,23 @@ class DefaultExecutionTopologyTest {
assertPartitionEquals(originalPartition, adaptedPartition);
- ConsumerVertexGroup consumerVertexGroup = originalPartition.getConsumerVertexGroup();
- Optional<ConsumerVertexGroup> adaptedConsumers =
- adaptedPartition.getConsumerVertexGroup();
- assertThat(adaptedConsumers).isPresent();
- for (ExecutionVertexID originalId : consumerVertexGroup) {
+ List<ExecutionVertexID> originalConsumerIds = new ArrayList<>();
+ for (ConsumerVertexGroup consumerVertexGroup :
+ originalPartition.getConsumerVertexGroups()) {
+ for (ExecutionVertexID executionVertexId : consumerVertexGroup) {
+ originalConsumerIds.add(executionVertexId);
+ }
+ }
+ List<ConsumerVertexGroup> adaptedConsumers = adaptedPartition.getConsumerVertexGroups();
+ assertThat(adaptedConsumers).isNotEmpty();
+ for (ExecutionVertexID originalId : originalConsumerIds) {
// it is sufficient to verify that some vertex exists with the correct ID here,
// since deep equality is verified later in the main loop
// this DOES rely on an implicit assumption that the vertices objects returned by
// the topology are
// identical to those stored in the partition
- assertThat(adaptedConsumers.get()).contains(originalId);
+ assertThat(adaptedConsumers.stream().flatMap(IterableUtils::toStream))
+ .contains(originalId);
}
}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
index fc2df697464..d6fb9ace1a3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultExecutionVertexTest.java
@@ -81,6 +81,7 @@ class DefaultExecutionVertexTest {
List<ConsumedPartitionGroup> consumedPartitionGroups =
Collections.singletonList(
ConsumedPartitionGroup.fromSinglePartition(
+ 1,
intermediateResultPartitionId,
schedulingResultPartition.getResultType()));
Map<IntermediateResultPartitionID, DefaultResultPartition> resultPartitionById =
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
index 6d0024626c1..9f8c58400e3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adapter/DefaultResultPartitionTest.java
@@ -28,7 +28,10 @@ import org.apache.flink.runtime.scheduler.strategy.ResultPartitionState;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
+import java.util.ArrayList;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
@@ -47,8 +50,8 @@ class DefaultResultPartitionTest {
private DefaultResultPartition resultPartition;
- private final Map<IntermediateResultPartitionID, ConsumerVertexGroup> consumerVertexGroups =
- new HashMap<>();
+ private final Map<IntermediateResultPartitionID, List<ConsumerVertexGroup>>
+ consumerVertexGroups = new HashMap<>();
@BeforeEach
void setUp() {
@@ -58,7 +61,9 @@ class DefaultResultPartitionTest {
intermediateResultId,
BLOCKING,
resultPartitionState,
- () -> consumerVertexGroups.get(resultPartitionId),
+ () ->
+ consumerVertexGroups.computeIfAbsent(
+ resultPartitionId, ignored -> new ArrayList<>()),
() -> {
throw new UnsupportedOperationException();
});
@@ -75,14 +80,15 @@ class DefaultResultPartitionTest {
@Test
void testGetConsumerVertexGroup() {
- assertThat(resultPartition.getConsumerVertexGroup()).isNotPresent();
+ assertThat(resultPartition.getConsumerVertexGroups()).isEmpty();
// test update consumers
ExecutionVertexID executionVertexId = new ExecutionVertexID(new JobVertexID(), 0);
consumerVertexGroups.put(
- resultPartition.getId(), ConsumerVertexGroup.fromSingleVertex(executionVertexId));
- assertThat(resultPartition.getConsumerVertexGroup()).isPresent();
- assertThat(resultPartition.getConsumerVertexGroup().get()).contains(executionVertexId);
+ resultPartition.getId(),
+ Collections.singletonList(ConsumerVertexGroup.fromSingleVertex(executionVertexId)));
+ assertThat(resultPartition.getConsumerVertexGroups()).isNotEmpty();
+ assertThat(resultPartition.getConsumerVertexGroups().get(0)).contains(executionVertexId);
}
/** A test {@link ResultPartitionState} supplier. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
index 3c9bed61354..a6e87f9bfc2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/AdaptiveBatchSchedulerTest.java
@@ -166,7 +166,7 @@ class AdaptiveBatchSchedulerTest {
sink.connectNewDataSetAsInput(
source2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
if (withForwardEdge) {
- source1.getProducedDataSets().get(0).getConsumer().setForward(true);
+ source1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
}
return new JobGraph(new JobID(), "test job", source1, source2, sink);
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
index b45ba3f5ebe..e9b0da1509d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/adaptivebatch/forwardgroup/ForwardGroupComputeUtilTest.java
@@ -97,14 +97,14 @@ class ForwardGroupComputeUtilTest {
v2.connectNewDataSetAsInput(
v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
if (isForward1) {
- v1.getProducedDataSets().get(0).getConsumer().setForward(true);
+ v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
}
v3.connectNewDataSetAsInput(
v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
if (isForward2) {
- v2.getProducedDataSets().get(0).getConsumer().setForward(true);
+ v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
}
Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3);
@@ -135,10 +135,10 @@ class ForwardGroupComputeUtilTest {
v3.connectNewDataSetAsInput(
v1, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
- v1.getProducedDataSets().get(0).getConsumer().setForward(true);
+ v1.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
v3.connectNewDataSetAsInput(
v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
- v2.getProducedDataSets().get(0).getConsumer().setForward(true);
+ v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
v4.connectNewDataSetAsInput(
v3, DistributionPattern.ALL_TO_ALL, ResultPartitionType.BLOCKING);
@@ -174,8 +174,8 @@ class ForwardGroupComputeUtilTest {
v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
v4.connectNewDataSetAsInput(
v2, DistributionPattern.POINTWISE, ResultPartitionType.BLOCKING);
- v2.getProducedDataSets().get(0).getConsumer().setForward(true);
- v2.getProducedDataSets().get(1).getConsumer().setForward(true);
+ v2.getProducedDataSets().get(0).getConsumers().get(0).setForward(true);
+ v2.getProducedDataSets().get(1).getConsumers().get(0).setForward(true);
Set<ForwardGroup> groups = computeForwardGroups(v1, v2, v3, v4);
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
index 4621680d360..655f725874c 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingExecutionVertex.java
@@ -93,7 +93,9 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
void addConsumedPartition(TestingSchedulingResultPartition consumedPartition) {
final ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromSinglePartition(
- consumedPartition.getId(), consumedPartition.getResultType());
+ consumedPartition.getNumConsumers(),
+ consumedPartition.getId(),
+ consumedPartition.getResultType());
consumedPartition.registerConsumedPartitionGroup(consumedPartitionGroup);
if (consumedPartition.getState() == ResultPartitionState.CONSUMABLE) {
@@ -155,7 +157,8 @@ public class TestingSchedulingExecutionVertex implements SchedulingExecutionVert
partitionIds.add(partitionId);
}
this.consumedPartitionGroups.add(
- ConsumedPartitionGroup.fromMultiplePartitions(partitionIds, resultType));
+ ConsumedPartitionGroup.fromMultiplePartitions(
+ partitionGroup.getNumConsumers(), partitionIds, resultType));
}
return this;
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
index 9454268cb5f..6274eecb285 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingResultPartition.java
@@ -28,7 +28,6 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
-import java.util.Optional;
import java.util.stream.Collectors;
import static org.apache.flink.util.Preconditions.checkNotNull;
@@ -64,6 +63,10 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
this.consumedPartitionGroups = new ArrayList<>();
}
+    /**
+     * Returns the size of the consumer vertex group of this partition, or {@code 1} if no consumer
+     * vertex group has been set.
+     */
+    public int getNumConsumers() {
+        if (consumerVertexGroup != null) {
+            return consumerVertexGroup.size();
+        }
+        return 1;
+    }
+
@Override
public IntermediateResultPartitionID getId() {
return intermediateResultPartitionID;
@@ -90,8 +93,8 @@ public class TestingSchedulingResultPartition implements SchedulingResultPartiti
}
@Override
- public Optional<ConsumerVertexGroup> getConsumerVertexGroup() {
- return Optional.of(consumerVertexGroup);
+ public List<ConsumerVertexGroup> getConsumerVertexGroups() {
+ return Collections.singletonList(consumerVertexGroup);
}
@Override
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
index 82463a1a2cb..95596369e33 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/strategy/TestingSchedulingTopology.java
@@ -350,6 +350,7 @@ public class TestingSchedulingTopology implements SchedulingTopology {
ConsumedPartitionGroup consumedPartitionGroup =
ConsumedPartitionGroup.fromMultiplePartitions(
+ consumers.size(),
resultPartitions.stream()
.map(TestingSchedulingResultPartition::getId)
.collect(Collectors.toList()),
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
index be72cacff5c..4cf9ed7c990 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGenerator.java
@@ -1079,19 +1079,20 @@ public class StreamingJobGraphGenerator {
headVertex,
DistributionPattern.POINTWISE,
resultPartitionType,
- intermediateDataSetID);
+ intermediateDataSetID,
+ partitioner.isBroadcast());
} else {
jobEdge =
downStreamVertex.connectNewDataSetAsInput(
headVertex,
DistributionPattern.ALL_TO_ALL,
resultPartitionType,
- intermediateDataSetID);
+ intermediateDataSetID,
+ partitioner.isBroadcast());
}
// set strategy name so that web interface can show it.
jobEdge.setShipStrategyName(partitioner.toString());
- jobEdge.setBroadcast(partitioner.isBroadcast());
jobEdge.setForward(partitioner instanceof ForwardPartitioner);
jobEdge.setDownstreamSubtaskStateMapper(partitioner.getDownstreamSubtaskStateMapper());
jobEdge.setUpstreamSubtaskStateMapper(partitioner.getUpstreamSubtaskStateMapper());
diff --git a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
index ab50814edd2..137e5040383 100644
--- a/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
+++ b/flink-tests/src/test/java/org/apache/flink/runtime/operators/lifecycle/validation/TestJobDataFlowValidator.java
@@ -55,24 +55,26 @@ public class TestJobDataFlowValidator {
for (JobVertex upstream : testJob.jobGraph.getVertices()) {
for (IntermediateDataSet produced : upstream.getProducedDataSets()) {
- JobEdge edge = produced.getConsumer();
- Optional<String> upstreamIDOptional = getTrackedOperatorID(upstream, true, testJob);
- Optional<String> downstreamIDOptional =
- getTrackedOperatorID(edge.getTarget(), false, testJob);
- if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
- final String upstreamID = upstreamIDOptional.get();
- final String downstreamID = downstreamIDOptional.get();
- if (testJob.sources.contains(upstreamID)) {
- // TODO: if we add tests for FLIP-27 sources we might need to adjust
- // this condition
- LOG.debug(
- "Legacy sources do not have the finish() method and thus do not"
- + " emit FinishEvent");
+ for (JobEdge edge : produced.getConsumers()) {
+ Optional<String> upstreamIDOptional =
+ getTrackedOperatorID(upstream, true, testJob);
+ Optional<String> downstreamIDOptional =
+ getTrackedOperatorID(edge.getTarget(), false, testJob);
+ if (upstreamIDOptional.isPresent() && downstreamIDOptional.isPresent()) {
+ final String upstreamID = upstreamIDOptional.get();
+ final String downstreamID = downstreamIDOptional.get();
+ if (testJob.sources.contains(upstreamID)) {
+ // TODO: if we add tests for FLIP-27 sources we might need to adjust
+ // this condition
+ LOG.debug(
+ "Legacy sources do not have the finish() method and thus do not"
+ + " emit FinishEvent");
+ } else {
+ checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+ }
} else {
- checkDataFlow(upstreamID, downstreamID, edge, finishEvents, withDrain);
+ LOG.debug("Ignoring edge (untracked operator): {}", edge);
}
- } else {
- LOG.debug("Ignoring edge (untracked operator): {}", edge);
}
}
}