You are viewing a plain text version of this content. The canonical (hyperlinked) version is available in the mailing-list archive; the link was not preserved in this plain-text rendering.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/01/25 18:34:52 UTC
[flink] branch master updated: [FLINK-25668][runtime] Support to compute network memory for dynamic graph.
This is an automated email from the ASF dual-hosted git repository.
zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 7962a5d [FLINK-25668][runtime] Support to compute network memory for dynamic graph.
7962a5d is described below
commit 7962a5d9c47f49f435f822a1af2c4141c42a849b
Author: Lijie Wang <wa...@gmail.com>
AuthorDate: Tue Dec 14 22:04:20 2021 +0800
[FLINK-25668][runtime] Support to compute network memory for dynamic graph.
This closes #18376.
---
.../runtime/deployment/SubpartitionIndexRange.java | 4 +
.../TaskDeploymentDescriptorFactory.java | 24 ++-
.../executiongraph/DefaultExecutionGraph.java | 20 ++
.../runtime/executiongraph/IntermediateResult.java | 2 +-
.../flink/runtime/scheduler/DefaultScheduler.java | 11 -
.../SsgNetworkMemoryCalculationUtils.java | 63 +++++-
.../executiongraph/ExecutionJobVertexTest.java | 4 +-
.../IntermediateResultPartitionTest.java | 2 +-
.../SsgNetworkMemoryCalculationUtilsTest.java | 228 +++++++++++++++++----
9 files changed, 291 insertions(+), 67 deletions(-)
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java
index 1fb1d52..19484a6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java
@@ -43,6 +43,10 @@ public class SubpartitionIndexRange implements Serializable {
return endIndex;
}
+ public int size() {
+ return endIndex - startIndex + 1;
+ }
+
@Override
public String toString() {
return String.format("[%d, %d]", startIndex, endIndex);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index 528a954..bd0f5b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -128,17 +128,9 @@ public class TaskDeploymentDescriptorFactory {
IntermediateResultPartition resultPartition =
resultPartitionRetriever.apply(consumedPartitionGroup.getFirst());
- int numConsumers = resultPartition.getConsumerVertexGroup().size();
IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
- int consumerIndex = subtaskIndex % numConsumers;
- int numSubpartitions = resultPartition.getNumberOfSubpartitions();
SubpartitionIndexRange consumedSubpartitionRange =
- computeConsumedSubpartitionRange(
- consumerIndex,
- numConsumers,
- numSubpartitions,
- consumedIntermediateResult.getProducer().getGraph().isDynamic(),
- consumedIntermediateResult.isBroadcast());
+ computeConsumedSubpartitionRange(resultPartition, subtaskIndex);
IntermediateDataSetID resultId = consumedIntermediateResult.getId();
ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
@@ -155,6 +147,20 @@ public class TaskDeploymentDescriptorFactory {
return inputGates;
}
+ public static SubpartitionIndexRange computeConsumedSubpartitionRange(
+ IntermediateResultPartition resultPartition, int consumerSubtaskIndex) {
+ int numConsumers = resultPartition.getConsumerVertexGroup().size();
+ int consumerIndex = consumerSubtaskIndex % numConsumers;
+ IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
+ int numSubpartitions = resultPartition.getNumberOfSubpartitions();
+ return computeConsumedSubpartitionRange(
+ consumerIndex,
+ numConsumers,
+ numSubpartitions,
+ consumedIntermediateResult.getProducer().getGraph().isDynamic(),
+ consumedIntermediateResult.isBroadcast());
+ }
+
@VisibleForTesting
static SubpartitionIndexRange computeConsumedSubpartitionRange(
int consumerIndex,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 8796bc7..77f0e19 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -59,8 +59,10 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
+import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
import org.apache.flink.runtime.query.KvStateLocationRegistry;
import org.apache.flink.runtime.scheduler.InternalFailuresListener;
+import org.apache.flink.runtime.scheduler.SsgNetworkMemoryCalculationUtils;
import org.apache.flink.runtime.scheduler.VertexParallelismInformation;
import org.apache.flink.runtime.scheduler.VertexParallelismStore;
import org.apache.flink.runtime.scheduler.adapter.DefaultExecutionTopology;
@@ -854,6 +856,24 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
}
registerExecutionVerticesAndResultPartitionsFor(ejv);
+
+ // enrich network memory.
+ SlotSharingGroup slotSharingGroup = ejv.getSlotSharingGroup();
+ if (areJobVerticesAllInitialized(slotSharingGroup)) {
+ SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
+ slotSharingGroup, this::getJobVertex, shuffleMaster);
+ }
+ }
+
+ private boolean areJobVerticesAllInitialized(final SlotSharingGroup group) {
+ for (JobVertexID jobVertexId : group.getJobVertexIds()) {
+ final ExecutionJobVertex jobVertex = getJobVertex(jobVertexId);
+ checkNotNull(jobVertex, "Unknown job vertex %s", jobVertexId);
+ if (!jobVertex.isInitialized()) {
+ return false;
+ }
+ }
+ return true;
}
@Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
index 508e974..4b666b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
@@ -168,7 +168,7 @@ public class IntermediateResult {
return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId));
}
- DistributionPattern getConsumingDistributionPattern() {
+ public DistributionPattern getConsumingDistributionPattern() {
final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
return consumer.getDistributionPattern();
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
index bd202ba..1a2fa36 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
@@ -173,8 +173,6 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
jobGraph.getName(),
jobGraph.getJobID());
- enrichResourceProfile();
-
this.executionFailureHandler =
new ExecutionFailureHandler(
getSchedulingTopology(), failoverStrategy, restartBackoffTimeStrategy);
@@ -723,13 +721,4 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
return reservedAllocationRefCounters.keySet();
}
}
-
- private void enrichResourceProfile() {
- Set<SlotSharingGroup> ssgs = new HashSet<>();
- getJobGraph().getVertices().forEach(jv -> ssgs.add(jv.getSlotSharingGroup()));
- ssgs.forEach(
- ssg ->
- SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
- ssg, this::getExecutionJobVertex, shuffleMaster));
- }
}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index 13c6172..0fac885 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -18,11 +18,16 @@
package org.apache.flink.runtime.scheduler;
+import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
+import org.apache.flink.runtime.deployment.SubpartitionIndexRange;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory;
import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -30,9 +35,11 @@ import org.apache.flink.runtime.jobgraph.JobEdge;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.runtime.shuffle.ShuffleMaster;
import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -51,7 +58,7 @@ public class SsgNetworkMemoryCalculationUtils {
* Calculates network memory requirement of {@link ExecutionJobVertex} and update {@link
* ResourceProfile} of corresponding slot sharing group.
*/
- static void enrichNetworkMemory(
+ public static void enrichNetworkMemory(
SlotSharingGroup ssg,
Function<JobVertexID, ExecutionJobVertex> ejvs,
ShuffleMaster<?> shuffleMaster) {
@@ -88,8 +95,17 @@ public class SsgNetworkMemoryCalculationUtils {
private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
- Map<IntermediateDataSetID, Integer> maxInputChannelNums = getMaxInputChannelNums(ejv);
- Map<IntermediateDataSetID, Integer> maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs);
+ Map<IntermediateDataSetID, Integer> maxInputChannelNums;
+ Map<IntermediateDataSetID, Integer> maxSubpartitionNums;
+
+ if (ejv.getGraph().isDynamic()) {
+ maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv);
+ maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv);
+ } else {
+ maxInputChannelNums = getMaxInputChannelNums(ejv);
+ maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs);
+ }
+
JobVertex jv = ejv.getJobVertex();
Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = getPartitionTypes(jv);
@@ -148,6 +164,47 @@ public class SsgNetworkMemoryCalculationUtils {
return ret;
}
+ @VisibleForTesting
+ static Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph(
+ ExecutionJobVertex ejv) {
+
+ Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+
+ for (ExecutionVertex vertex : ejv.getTaskVertices()) {
+ for (ConsumedPartitionGroup partitionGroup : vertex.getAllConsumedPartitionGroups()) {
+
+ IntermediateResultPartition resultPartition =
+ ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst()));
+ SubpartitionIndexRange subpartitionIndexRange =
+ TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
+ resultPartition, vertex.getParallelSubtaskIndex());
+
+ ret.merge(
+ partitionGroup.getIntermediateDataSetID(),
+ subpartitionIndexRange.size() * partitionGroup.size(),
+ Integer::max);
+ }
+ }
+
+ return ret;
+ }
+
+ private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph(
+ ExecutionJobVertex ejv) {
+
+ Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+
+ for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) {
+ final int maxNum =
+ Arrays.stream(intermediateResult.getPartitions())
+ .map(IntermediateResultPartition::getNumberOfSubpartitions)
+ .reduce(0, Integer::max);
+ ret.put(intermediateResult.getId(), maxNum);
+ }
+
+ return ret;
+ }
+
/** Private default constructor to avoid being instantiated. */
private SsgNetworkMemoryCalculationUtils() {}
}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index 186cbf5..240427b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -193,7 +193,7 @@ public class ExecutionJobVertexTest {
return createDynamicExecutionJobVertex(-1, -1, 1);
}
- private static ExecutionJobVertex createDynamicExecutionJobVertex(
+ public static ExecutionJobVertex createDynamicExecutionJobVertex(
int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception {
JobVertex jobVertex = new JobVertex("testVertex");
jobVertex.setInvokableClass(AbstractInvokable.class);
@@ -227,7 +227,7 @@ public class ExecutionJobVertexTest {
* @param defaultMaxParallelism the global default max parallelism
* @return the computed parallelism store
*/
- static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph(
+ public static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph(
Iterable<JobVertex> vertices, int defaultMaxParallelism) {
// for dynamic graph, there is no need to normalize vertex parallelism. if the max
// parallelism is not configured and the parallelism is a positive value, max
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 42f60b2..611e18a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -221,7 +221,7 @@ public class IntermediateResultPartitionTest extends TestLogger {
equalTo(expectedNumSubpartitions));
}
- private static ExecutionGraph createExecutionGraph(
+ public static ExecutionGraph createExecutionGraph(
int producerParallelism,
int consumerParallelism,
int consumerMaxParallelism,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
index 54ad8eb..9481c82 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
@@ -21,10 +21,16 @@ package org.apache.flink.runtime.scheduler;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.MemorySize;
import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
-import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertexTest;
+import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest;
import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -39,9 +45,13 @@ import org.apache.flink.runtime.testtasks.NoOpInvokable;
import org.junit.Test;
import java.util.Arrays;
+import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
/** Tests for {@link SsgNetworkMemoryCalculationUtils}. */
@@ -51,81 +61,219 @@ public class SsgNetworkMemoryCalculationUtilsTest {
private static final ResourceProfile DEFAULT_RESOURCE = ResourceProfile.fromResources(1.0, 100);
- private JobGraph jobGraph;
-
- private ExecutionGraph executionGraph;
-
- private List<SlotSharingGroup> slotSharingGroups;
-
@Test
public void testGenerateEnrichedResourceProfile() throws Exception {
- setup(DEFAULT_RESOURCE);
- slotSharingGroups.forEach(
- ssg ->
- SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
- ssg, executionGraph.getAllVertices()::get, SHUFFLE_MASTER));
+ SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup();
+ slotSharingGroup0.setResourceProfile(DEFAULT_RESOURCE);
+
+ SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup();
+ slotSharingGroup1.setResourceProfile(DEFAULT_RESOURCE);
+
+ createExecutionGraphAndEnrichNetworkMemory(
+ Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1));
assertEquals(
new MemorySize(
TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 2)
+ TestShuffleMaster.computeRequiredShuffleMemoryBytes(1, 6)),
- slotSharingGroups.get(0).getResourceProfile().getNetworkMemory());
-
+ slotSharingGroup0.getResourceProfile().getNetworkMemory());
assertEquals(
new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 0)),
- slotSharingGroups.get(1).getResourceProfile().getNetworkMemory());
+ slotSharingGroup1.getResourceProfile().getNetworkMemory());
}
@Test
public void testGenerateUnknownResourceProfile() throws Exception {
- setup(ResourceProfile.UNKNOWN);
+ SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup();
+ slotSharingGroup0.setResourceProfile(ResourceProfile.UNKNOWN);
+
+ SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup();
+ slotSharingGroup1.setResourceProfile(ResourceProfile.UNKNOWN);
+
+ createExecutionGraphAndEnrichNetworkMemory(
+ Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1));
+
+ assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup0.getResourceProfile());
+ assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup1.getResourceProfile());
+ }
+
+ @Test
+ public void testGenerateEnrichedResourceProfileForDynamicGraph() throws Exception {
+ List<SlotSharingGroup> slotSharingGroups =
+ Arrays.asList(
+ new SlotSharingGroup(), new SlotSharingGroup(), new SlotSharingGroup());
+
+ for (SlotSharingGroup group : slotSharingGroups) {
+ group.setResourceProfile(DEFAULT_RESOURCE);
+ }
+
+ DefaultExecutionGraph executionGraph = createDynamicExecutionGraph(slotSharingGroups, 20);
+ Iterator<ExecutionJobVertex> jobVertices =
+ executionGraph.getVerticesTopologically().iterator();
+ ExecutionJobVertex source = jobVertices.next();
+ ExecutionJobVertex map = jobVertices.next();
+ ExecutionJobVertex sink = jobVertices.next();
- slotSharingGroups.forEach(
- ssg ->
- SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
- ssg, executionGraph.getAllVertices()::get, SHUFFLE_MASTER));
+ executionGraph.initializeJobVertex(source, 0L);
+ triggerComputeNumOfSubpartitions(source.getProducedDataSets()[0]);
+
+ map.setParallelism(5);
+ executionGraph.initializeJobVertex(map, 0L);
+ triggerComputeNumOfSubpartitions(map.getProducedDataSets()[0]);
+
+ sink.setParallelism(7);
+ executionGraph.initializeJobVertex(sink, 0L);
+
+ assertNetworkMemory(
+ slotSharingGroups,
+ Arrays.asList(
+ new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 5)),
+ new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 20)),
+ new MemorySize(
+ TestShuffleMaster.computeRequiredShuffleMemoryBytes(15, 0))));
+ }
- for (SlotSharingGroup slotSharingGroup : slotSharingGroups) {
- assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup.getResourceProfile());
+ private void triggerComputeNumOfSubpartitions(IntermediateResult result) {
+ // call IntermediateResultPartition#getNumberOfSubpartitions to trigger computation of
+ // numOfSubpartitions
+ for (IntermediateResultPartition partition : result.getPartitions()) {
+ partition.getNumberOfSubpartitions();
}
}
- private void setup(final ResourceProfile resourceProfile) throws Exception {
- slotSharingGroups = Arrays.asList(new SlotSharingGroup(), new SlotSharingGroup());
+ private void assertNetworkMemory(
+ List<SlotSharingGroup> slotSharingGroups, List<MemorySize> networkMemory) {
- for (SlotSharingGroup slotSharingGroup : slotSharingGroups) {
- slotSharingGroup.setResourceProfile(resourceProfile);
+ assertEquals(slotSharingGroups.size(), networkMemory.size());
+ for (int i = 0; i < slotSharingGroups.size(); ++i) {
+ assertThat(
+ slotSharingGroups.get(i).getResourceProfile().getNetworkMemory(),
+ is(networkMemory.get(i)));
}
+ }
+
+ @Test
+ public void testGetMaxInputChannelNumForResultForAllToAll() throws Exception {
+ testGetMaxInputChannelNumForResult(DistributionPattern.ALL_TO_ALL, 5, 20, 7, 15);
+ }
- jobGraph = createJobGraph(slotSharingGroups);
- executionGraph =
- TestingDefaultExecutionGraphBuilder.newBuilder().setJobGraph(jobGraph).build();
+ @Test
+ public void testGetMaxInputChannelNumForResultForPointWise() throws Exception {
+ testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 3, 8);
+ testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 5, 4);
+ testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 7, 4);
}
- private static JobGraph createJobGraph(final List<SlotSharingGroup> slotSharingGroups) {
+ private void testGetMaxInputChannelNumForResult(
+ DistributionPattern distributionPattern,
+ int producerParallelism,
+ int consumerMaxParallelism,
+ int decidedConsumerParallelism,
+ int expectedNumChannels)
+ throws Exception {
+
+ final DefaultExecutionGraph eg =
+ (DefaultExecutionGraph)
+ IntermediateResultPartitionTest.createExecutionGraph(
+ producerParallelism,
+ -1,
+ consumerMaxParallelism,
+ distributionPattern,
+ true);
+
+ final Iterator<ExecutionJobVertex> vertexIterator =
+ eg.getVerticesTopologically().iterator();
+ final ExecutionJobVertex producer = vertexIterator.next();
+ final ExecutionJobVertex consumer = vertexIterator.next();
+
+ eg.initializeJobVertex(producer, 0L);
+ final IntermediateResult result = producer.getProducedDataSets()[0];
+ triggerComputeNumOfSubpartitions(result);
+
+ consumer.setParallelism(decidedConsumerParallelism);
+ eg.initializeJobVertex(consumer, 0L);
+
+ Map<IntermediateDataSetID, Integer> maxInputChannelNums =
+ SsgNetworkMemoryCalculationUtils.getMaxInputChannelNumsForDynamicGraph(consumer);
+
+ assertThat(maxInputChannelNums.size(), is(1));
+ assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels));
+ }
+
+ private DefaultExecutionGraph createDynamicExecutionGraph(
+ final List<SlotSharingGroup> slotSharingGroups, int defaultMaxParallelism)
+ throws Exception {
+
+ JobGraph jobGraph = createBatchGraph(slotSharingGroups, Arrays.asList(4, -1, -1));
+
+ final VertexParallelismStore vertexParallelismStore =
+ ExecutionJobVertexTest.computeVertexParallelismStoreForDynamicGraph(
+ jobGraph.getVertices(), defaultMaxParallelism);
+
+ return TestingDefaultExecutionGraphBuilder.newBuilder()
+ .setJobGraph(jobGraph)
+ .setVertexParallelismStore(vertexParallelismStore)
+ .setShuffleMaster(SHUFFLE_MASTER)
+ .buildDynamicGraph();
+ }
+
+ private void createExecutionGraphAndEnrichNetworkMemory(
+ final List<SlotSharingGroup> slotSharingGroups) throws Exception {
+ TestingDefaultExecutionGraphBuilder.newBuilder()
+ .setJobGraph(createStreamingGraph(slotSharingGroups, Arrays.asList(4, 5, 6)))
+ .setShuffleMaster(SHUFFLE_MASTER)
+ .build();
+ }
+
+ private static JobGraph createStreamingGraph(
+ final List<SlotSharingGroup> slotSharingGroups, List<Integer> parallelisms) {
+ return createJobGraph(slotSharingGroups, parallelisms, ResultPartitionType.PIPELINED);
+ }
+
+ private static JobGraph createBatchGraph(
+ final List<SlotSharingGroup> slotSharingGroups, List<Integer> parallelisms) {
+ return createJobGraph(slotSharingGroups, parallelisms, ResultPartitionType.BLOCKING);
+ }
+
+ private static JobGraph createJobGraph(
+ final List<SlotSharingGroup> slotSharingGroups,
+ List<Integer> parallelisms,
+ ResultPartitionType resultPartitionType) {
+
+ assertThat(slotSharingGroups.size(), is(3));
+ assertThat(parallelisms.size(), is(3));
JobVertex source = new JobVertex("source");
source.setInvokableClass(NoOpInvokable.class);
- source.setParallelism(4);
+ trySetParallelism(source, parallelisms.get(0));
source.setSlotSharingGroup(slotSharingGroups.get(0));
JobVertex map = new JobVertex("map");
map.setInvokableClass(NoOpInvokable.class);
- map.setParallelism(5);
- map.setSlotSharingGroup(slotSharingGroups.get(0));
+ trySetParallelism(map, parallelisms.get(1));
+ map.setSlotSharingGroup(slotSharingGroups.get(1));
JobVertex sink = new JobVertex("sink");
sink.setInvokableClass(NoOpInvokable.class);
- sink.setParallelism(6);
- sink.setSlotSharingGroup(slotSharingGroups.get(1));
+ trySetParallelism(sink, parallelisms.get(2));
+ sink.setSlotSharingGroup(slotSharingGroups.get(2));
+
+ map.connectNewDataSetAsInput(source, DistributionPattern.POINTWISE, resultPartitionType);
+ sink.connectNewDataSetAsInput(map, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+
+ if (resultPartitionType.isPipelined()) {
+ return JobGraphTestUtils.streamingJobGraph(source, map, sink);
- map.connectNewDataSetAsInput(
- source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
- sink.connectNewDataSetAsInput(
- map, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+ } else {
+ return JobGraphTestUtils.batchJobGraph(source, map, sink);
+ }
+ }
- return JobGraphTestUtils.streamingJobGraph(source, map, sink);
+ private static void trySetParallelism(JobVertex jobVertex, int parallelism) {
+ if (parallelism > 0) {
+ jobVertex.setParallelism(parallelism);
+ }
}
private static class TestShuffleMaster implements ShuffleMaster<ShuffleDescriptor> {