You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this copy.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/01/25 18:34:52 UTC

[flink] branch master updated: [FLINK-25668][runtime] Support to compute network memory for dynamic graph.

This is an automated email from the ASF dual-hosted git repository.

zhuzh pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 7962a5d  [FLINK-25668][runtime] Support to compute network memory for dynamic graph.
7962a5d is described below

commit 7962a5d9c47f49f435f822a1af2c4141c42a849b
Author: Lijie Wang <wa...@gmail.com>
AuthorDate: Tue Dec 14 22:04:20 2021 +0800

    [FLINK-25668][runtime] Support to compute network memory for dynamic graph.
    
    This closes #18376.
---
 .../runtime/deployment/SubpartitionIndexRange.java |   4 +
 .../TaskDeploymentDescriptorFactory.java           |  24 ++-
 .../executiongraph/DefaultExecutionGraph.java      |  20 ++
 .../runtime/executiongraph/IntermediateResult.java |   2 +-
 .../flink/runtime/scheduler/DefaultScheduler.java  |  11 -
 .../SsgNetworkMemoryCalculationUtils.java          |  63 +++++-
 .../executiongraph/ExecutionJobVertexTest.java     |   4 +-
 .../IntermediateResultPartitionTest.java           |   2 +-
 .../SsgNetworkMemoryCalculationUtilsTest.java      | 228 +++++++++++++++++----
 9 files changed, 291 insertions(+), 67 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java
index 1fb1d52..19484a6 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/SubpartitionIndexRange.java
@@ -43,6 +43,10 @@ public class SubpartitionIndexRange implements Serializable {
         return endIndex;
     }
 
+    public int size() {
+        return endIndex - startIndex + 1;
+    }
+
     @Override
     public String toString() {
         return String.format("[%d, %d]", startIndex, endIndex);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
index 528a954..bd0f5b3 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorFactory.java
@@ -128,17 +128,9 @@ public class TaskDeploymentDescriptorFactory {
             IntermediateResultPartition resultPartition =
                     resultPartitionRetriever.apply(consumedPartitionGroup.getFirst());
 
-            int numConsumers = resultPartition.getConsumerVertexGroup().size();
             IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
-            int consumerIndex = subtaskIndex % numConsumers;
-            int numSubpartitions = resultPartition.getNumberOfSubpartitions();
             SubpartitionIndexRange consumedSubpartitionRange =
-                    computeConsumedSubpartitionRange(
-                            consumerIndex,
-                            numConsumers,
-                            numSubpartitions,
-                            consumedIntermediateResult.getProducer().getGraph().isDynamic(),
-                            consumedIntermediateResult.isBroadcast());
+                    computeConsumedSubpartitionRange(resultPartition, subtaskIndex);
 
             IntermediateDataSetID resultId = consumedIntermediateResult.getId();
             ResultPartitionType partitionType = consumedIntermediateResult.getResultType();
@@ -155,6 +147,20 @@ public class TaskDeploymentDescriptorFactory {
         return inputGates;
     }
 
+    public static SubpartitionIndexRange computeConsumedSubpartitionRange(
+            IntermediateResultPartition resultPartition, int consumerSubtaskIndex) {
+        int numConsumers = resultPartition.getConsumerVertexGroup().size();
+        int consumerIndex = consumerSubtaskIndex % numConsumers;
+        IntermediateResult consumedIntermediateResult = resultPartition.getIntermediateResult();
+        int numSubpartitions = resultPartition.getNumberOfSubpartitions();
+        return computeConsumedSubpartitionRange(
+                consumerIndex,
+                numConsumers,
+                numSubpartitions,
+                consumedIntermediateResult.getProducer().getGraph().isDynamic(),
+                consumedIntermediateResult.isBroadcast());
+    }
+
     @VisibleForTesting
     static SubpartitionIndexRange computeConsumedSubpartitionRange(
             int consumerIndex,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
index 8796bc7..77f0e19 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/DefaultExecutionGraph.java
@@ -59,8 +59,10 @@ import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.CheckpointCoordinatorConfiguration;
+import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
 import org.apache.flink.runtime.query.KvStateLocationRegistry;
 import org.apache.flink.runtime.scheduler.InternalFailuresListener;
+import org.apache.flink.runtime.scheduler.SsgNetworkMemoryCalculationUtils;
 import org.apache.flink.runtime.scheduler.VertexParallelismInformation;
 import org.apache.flink.runtime.scheduler.VertexParallelismStore;
 import org.apache.flink.runtime.scheduler.adapter.DefaultExecutionTopology;
@@ -854,6 +856,24 @@ public class DefaultExecutionGraph implements ExecutionGraph, InternalExecutionG
         }
 
         registerExecutionVerticesAndResultPartitionsFor(ejv);
+
+        // enrich network memory.
+        SlotSharingGroup slotSharingGroup = ejv.getSlotSharingGroup();
+        if (areJobVerticesAllInitialized(slotSharingGroup)) {
+            SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
+                    slotSharingGroup, this::getJobVertex, shuffleMaster);
+        }
+    }
+
+    private boolean areJobVerticesAllInitialized(final SlotSharingGroup group) {
+        for (JobVertexID jobVertexId : group.getJobVertexIds()) {
+            final ExecutionJobVertex jobVertex = getJobVertex(jobVertexId);
+            checkNotNull(jobVertex, "Unknown job vertex %s", jobVertexId);
+            if (!jobVertex.isInitialized()) {
+                return false;
+            }
+        }
+        return true;
     }
 
     @Override
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
index 508e974..4b666b5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/executiongraph/IntermediateResult.java
@@ -168,7 +168,7 @@ public class IntermediateResult {
         return checkNotNull(getProducer().getGraph().getJobVertex(consumerJobVertexId));
     }
 
-    DistributionPattern getConsumingDistributionPattern() {
+    public DistributionPattern getConsumingDistributionPattern() {
         final JobEdge consumer = checkNotNull(intermediateDataSet.getConsumer());
         return consumer.getDistributionPattern();
     }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
index bd202ba..1a2fa36 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/DefaultScheduler.java
@@ -173,8 +173,6 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
                 jobGraph.getName(),
                 jobGraph.getJobID());
 
-        enrichResourceProfile();
-
         this.executionFailureHandler =
                 new ExecutionFailureHandler(
                         getSchedulingTopology(), failoverStrategy, restartBackoffTimeStrategy);
@@ -723,13 +721,4 @@ public class DefaultScheduler extends SchedulerBase implements SchedulerOperatio
             return reservedAllocationRefCounters.keySet();
         }
     }
-
-    private void enrichResourceProfile() {
-        Set<SlotSharingGroup> ssgs = new HashSet<>();
-        getJobGraph().getVertices().forEach(jv -> ssgs.add(jv.getSlotSharingGroup()));
-        ssgs.forEach(
-                ssg ->
-                        SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
-                                ssg, this::getExecutionJobVertex, shuffleMaster));
-    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index 13c6172..0fac885 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -18,11 +18,16 @@
 
 package org.apache.flink.runtime.scheduler;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.configuration.MemorySize;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
+import org.apache.flink.runtime.deployment.SubpartitionIndexRange;
+import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory;
 import org.apache.flink.runtime.executiongraph.EdgeManagerBuildUtil;
 import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSet;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
@@ -30,9 +35,11 @@ import org.apache.flink.runtime.jobgraph.JobEdge;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmanager.scheduler.SlotSharingGroup;
+import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
 import org.apache.flink.runtime.shuffle.ShuffleMaster;
 import org.apache.flink.runtime.shuffle.TaskInputsOutputsDescriptor;
 
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -51,7 +58,7 @@ public class SsgNetworkMemoryCalculationUtils {
      * Calculates network memory requirement of {@link ExecutionJobVertex} and update {@link
      * ResourceProfile} of corresponding slot sharing group.
      */
-    static void enrichNetworkMemory(
+    public static void enrichNetworkMemory(
             SlotSharingGroup ssg,
             Function<JobVertexID, ExecutionJobVertex> ejvs,
             ShuffleMaster<?> shuffleMaster) {
@@ -88,8 +95,17 @@ public class SsgNetworkMemoryCalculationUtils {
     private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
             ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
 
-        Map<IntermediateDataSetID, Integer> maxInputChannelNums = getMaxInputChannelNums(ejv);
-        Map<IntermediateDataSetID, Integer> maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs);
+        Map<IntermediateDataSetID, Integer> maxInputChannelNums;
+        Map<IntermediateDataSetID, Integer> maxSubpartitionNums;
+
+        if (ejv.getGraph().isDynamic()) {
+            maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv);
+            maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv);
+        } else {
+            maxInputChannelNums = getMaxInputChannelNums(ejv);
+            maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs);
+        }
+
         JobVertex jv = ejv.getJobVertex();
         Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = getPartitionTypes(jv);
 
@@ -148,6 +164,47 @@ public class SsgNetworkMemoryCalculationUtils {
         return ret;
     }
 
+    @VisibleForTesting
+    static Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph(
+            ExecutionJobVertex ejv) {
+
+        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+
+        for (ExecutionVertex vertex : ejv.getTaskVertices()) {
+            for (ConsumedPartitionGroup partitionGroup : vertex.getAllConsumedPartitionGroups()) {
+
+                IntermediateResultPartition resultPartition =
+                        ejv.getGraph().getResultPartitionOrThrow((partitionGroup.getFirst()));
+                SubpartitionIndexRange subpartitionIndexRange =
+                        TaskDeploymentDescriptorFactory.computeConsumedSubpartitionRange(
+                                resultPartition, vertex.getParallelSubtaskIndex());
+
+                ret.merge(
+                        partitionGroup.getIntermediateDataSetID(),
+                        subpartitionIndexRange.size() * partitionGroup.size(),
+                        Integer::max);
+            }
+        }
+
+        return ret;
+    }
+
+    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph(
+            ExecutionJobVertex ejv) {
+
+        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+
+        for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) {
+            final int maxNum =
+                    Arrays.stream(intermediateResult.getPartitions())
+                            .map(IntermediateResultPartition::getNumberOfSubpartitions)
+                            .reduce(0, Integer::max);
+            ret.put(intermediateResult.getId(), maxNum);
+        }
+
+        return ret;
+    }
+
     /** Private default constructor to avoid being instantiated. */
     private SsgNetworkMemoryCalculationUtils() {}
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
index 186cbf5..240427b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionJobVertexTest.java
@@ -193,7 +193,7 @@ public class ExecutionJobVertexTest {
         return createDynamicExecutionJobVertex(-1, -1, 1);
     }
 
-    private static ExecutionJobVertex createDynamicExecutionJobVertex(
+    public static ExecutionJobVertex createDynamicExecutionJobVertex(
             int parallelism, int maxParallelism, int defaultMaxParallelism) throws Exception {
         JobVertex jobVertex = new JobVertex("testVertex");
         jobVertex.setInvokableClass(AbstractInvokable.class);
@@ -227,7 +227,7 @@ public class ExecutionJobVertexTest {
      * @param defaultMaxParallelism the global default max parallelism
      * @return the computed parallelism store
      */
-    static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph(
+    public static VertexParallelismStore computeVertexParallelismStoreForDynamicGraph(
             Iterable<JobVertex> vertices, int defaultMaxParallelism) {
         // for dynamic graph, there is no need to normalize vertex parallelism. if the max
         // parallelism is not configured and the parallelism is a positive value, max
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
index 42f60b2..611e18a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/IntermediateResultPartitionTest.java
@@ -221,7 +221,7 @@ public class IntermediateResultPartitionTest extends TestLogger {
                 equalTo(expectedNumSubpartitions));
     }
 
-    private static ExecutionGraph createExecutionGraph(
+    public static ExecutionGraph createExecutionGraph(
             int producerParallelism,
             int consumerParallelism,
             int consumerMaxParallelism,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
index 54ad8eb..9481c82 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
@@ -21,10 +21,16 @@ package org.apache.flink.runtime.scheduler;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.MemorySize;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
-import org.apache.flink.runtime.executiongraph.ExecutionGraph;
+import org.apache.flink.runtime.executiongraph.DefaultExecutionGraph;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
+import org.apache.flink.runtime.executiongraph.ExecutionJobVertexTest;
+import org.apache.flink.runtime.executiongraph.IntermediateResult;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
+import org.apache.flink.runtime.executiongraph.IntermediateResultPartitionTest;
 import org.apache.flink.runtime.executiongraph.TestingDefaultExecutionGraphBuilder;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
 import org.apache.flink.runtime.jobgraph.DistributionPattern;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobGraphTestUtils;
 import org.apache.flink.runtime.jobgraph.JobVertex;
@@ -39,9 +45,13 @@ import org.apache.flink.runtime.testtasks.NoOpInvokable;
 import org.junit.Test;
 
 import java.util.Arrays;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.MatcherAssert.assertThat;
 import static org.junit.Assert.assertEquals;
 
 /** Tests for {@link SsgNetworkMemoryCalculationUtils}. */
@@ -51,81 +61,219 @@ public class SsgNetworkMemoryCalculationUtilsTest {
 
     private static final ResourceProfile DEFAULT_RESOURCE = ResourceProfile.fromResources(1.0, 100);
 
-    private JobGraph jobGraph;
-
-    private ExecutionGraph executionGraph;
-
-    private List<SlotSharingGroup> slotSharingGroups;
-
     @Test
     public void testGenerateEnrichedResourceProfile() throws Exception {
-        setup(DEFAULT_RESOURCE);
 
-        slotSharingGroups.forEach(
-                ssg ->
-                        SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
-                                ssg, executionGraph.getAllVertices()::get, SHUFFLE_MASTER));
+        SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup();
+        slotSharingGroup0.setResourceProfile(DEFAULT_RESOURCE);
+
+        SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup();
+        slotSharingGroup1.setResourceProfile(DEFAULT_RESOURCE);
+
+        createExecutionGraphAndEnrichNetworkMemory(
+                Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1));
 
         assertEquals(
                 new MemorySize(
                         TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 2)
                                 + TestShuffleMaster.computeRequiredShuffleMemoryBytes(1, 6)),
-                slotSharingGroups.get(0).getResourceProfile().getNetworkMemory());
-
+                slotSharingGroup0.getResourceProfile().getNetworkMemory());
         assertEquals(
                 new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 0)),
-                slotSharingGroups.get(1).getResourceProfile().getNetworkMemory());
+                slotSharingGroup1.getResourceProfile().getNetworkMemory());
     }
 
     @Test
     public void testGenerateUnknownResourceProfile() throws Exception {
-        setup(ResourceProfile.UNKNOWN);
+        SlotSharingGroup slotSharingGroup0 = new SlotSharingGroup();
+        slotSharingGroup0.setResourceProfile(ResourceProfile.UNKNOWN);
+
+        SlotSharingGroup slotSharingGroup1 = new SlotSharingGroup();
+        slotSharingGroup1.setResourceProfile(ResourceProfile.UNKNOWN);
+
+        createExecutionGraphAndEnrichNetworkMemory(
+                Arrays.asList(slotSharingGroup0, slotSharingGroup0, slotSharingGroup1));
+
+        assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup0.getResourceProfile());
+        assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup1.getResourceProfile());
+    }
+
+    @Test
+    public void testGenerateEnrichedResourceProfileForDynamicGraph() throws Exception {
+        List<SlotSharingGroup> slotSharingGroups =
+                Arrays.asList(
+                        new SlotSharingGroup(), new SlotSharingGroup(), new SlotSharingGroup());
+
+        for (SlotSharingGroup group : slotSharingGroups) {
+            group.setResourceProfile(DEFAULT_RESOURCE);
+        }
+
+        DefaultExecutionGraph executionGraph = createDynamicExecutionGraph(slotSharingGroups, 20);
+        Iterator<ExecutionJobVertex> jobVertices =
+                executionGraph.getVerticesTopologically().iterator();
+        ExecutionJobVertex source = jobVertices.next();
+        ExecutionJobVertex map = jobVertices.next();
+        ExecutionJobVertex sink = jobVertices.next();
 
-        slotSharingGroups.forEach(
-                ssg ->
-                        SsgNetworkMemoryCalculationUtils.enrichNetworkMemory(
-                                ssg, executionGraph.getAllVertices()::get, SHUFFLE_MASTER));
+        executionGraph.initializeJobVertex(source, 0L);
+        triggerComputeNumOfSubpartitions(source.getProducedDataSets()[0]);
+
+        map.setParallelism(5);
+        executionGraph.initializeJobVertex(map, 0L);
+        triggerComputeNumOfSubpartitions(map.getProducedDataSets()[0]);
+
+        sink.setParallelism(7);
+        executionGraph.initializeJobVertex(sink, 0L);
+
+        assertNetworkMemory(
+                slotSharingGroups,
+                Arrays.asList(
+                        new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 5)),
+                        new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 20)),
+                        new MemorySize(
+                                TestShuffleMaster.computeRequiredShuffleMemoryBytes(15, 0))));
+    }
 
-        for (SlotSharingGroup slotSharingGroup : slotSharingGroups) {
-            assertEquals(ResourceProfile.UNKNOWN, slotSharingGroup.getResourceProfile());
+    private void triggerComputeNumOfSubpartitions(IntermediateResult result) {
+        // call IntermediateResultPartition#getNumberOfSubpartitions to trigger computation of
+        // numOfSubpartitions
+        for (IntermediateResultPartition partition : result.getPartitions()) {
+            partition.getNumberOfSubpartitions();
         }
     }
 
-    private void setup(final ResourceProfile resourceProfile) throws Exception {
-        slotSharingGroups = Arrays.asList(new SlotSharingGroup(), new SlotSharingGroup());
+    private void assertNetworkMemory(
+            List<SlotSharingGroup> slotSharingGroups, List<MemorySize> networkMemory) {
 
-        for (SlotSharingGroup slotSharingGroup : slotSharingGroups) {
-            slotSharingGroup.setResourceProfile(resourceProfile);
+        assertEquals(slotSharingGroups.size(), networkMemory.size());
+        for (int i = 0; i < slotSharingGroups.size(); ++i) {
+            assertThat(
+                    slotSharingGroups.get(i).getResourceProfile().getNetworkMemory(),
+                    is(networkMemory.get(i)));
         }
+    }
+
+    @Test
+    public void testGetMaxInputChannelNumForResultForAllToAll() throws Exception {
+        testGetMaxInputChannelNumForResult(DistributionPattern.ALL_TO_ALL, 5, 20, 7, 15);
+    }
 
-        jobGraph = createJobGraph(slotSharingGroups);
-        executionGraph =
-                TestingDefaultExecutionGraphBuilder.newBuilder().setJobGraph(jobGraph).build();
+    @Test
+    public void testGetMaxInputChannelNumForResultForPointWise() throws Exception {
+        testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 3, 8);
+        testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 5, 4);
+        testGetMaxInputChannelNumForResult(DistributionPattern.POINTWISE, 5, 20, 7, 4);
     }
 
-    private static JobGraph createJobGraph(final List<SlotSharingGroup> slotSharingGroups) {
+    private void testGetMaxInputChannelNumForResult(
+            DistributionPattern distributionPattern,
+            int producerParallelism,
+            int consumerMaxParallelism,
+            int decidedConsumerParallelism,
+            int expectedNumChannels)
+            throws Exception {
+
+        final DefaultExecutionGraph eg =
+                (DefaultExecutionGraph)
+                        IntermediateResultPartitionTest.createExecutionGraph(
+                                producerParallelism,
+                                -1,
+                                consumerMaxParallelism,
+                                distributionPattern,
+                                true);
+
+        final Iterator<ExecutionJobVertex> vertexIterator =
+                eg.getVerticesTopologically().iterator();
+        final ExecutionJobVertex producer = vertexIterator.next();
+        final ExecutionJobVertex consumer = vertexIterator.next();
+
+        eg.initializeJobVertex(producer, 0L);
+        final IntermediateResult result = producer.getProducedDataSets()[0];
+        triggerComputeNumOfSubpartitions(result);
+
+        consumer.setParallelism(decidedConsumerParallelism);
+        eg.initializeJobVertex(consumer, 0L);
+
+        Map<IntermediateDataSetID, Integer> maxInputChannelNums =
+                SsgNetworkMemoryCalculationUtils.getMaxInputChannelNumsForDynamicGraph(consumer);
+
+        assertThat(maxInputChannelNums.size(), is(1));
+        assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels));
+    }
+
+    private DefaultExecutionGraph createDynamicExecutionGraph(
+            final List<SlotSharingGroup> slotSharingGroups, int defaultMaxParallelism)
+            throws Exception {
+
+        JobGraph jobGraph = createBatchGraph(slotSharingGroups, Arrays.asList(4, -1, -1));
+
+        final VertexParallelismStore vertexParallelismStore =
+                ExecutionJobVertexTest.computeVertexParallelismStoreForDynamicGraph(
+                        jobGraph.getVertices(), defaultMaxParallelism);
+
+        return TestingDefaultExecutionGraphBuilder.newBuilder()
+                .setJobGraph(jobGraph)
+                .setVertexParallelismStore(vertexParallelismStore)
+                .setShuffleMaster(SHUFFLE_MASTER)
+                .buildDynamicGraph();
+    }
+
+    private void createExecutionGraphAndEnrichNetworkMemory(
+            final List<SlotSharingGroup> slotSharingGroups) throws Exception {
+        TestingDefaultExecutionGraphBuilder.newBuilder()
+                .setJobGraph(createStreamingGraph(slotSharingGroups, Arrays.asList(4, 5, 6)))
+                .setShuffleMaster(SHUFFLE_MASTER)
+                .build();
+    }
+
+    private static JobGraph createStreamingGraph(
+            final List<SlotSharingGroup> slotSharingGroups, List<Integer> parallelisms) {
+        return createJobGraph(slotSharingGroups, parallelisms, ResultPartitionType.PIPELINED);
+    }
+
+    private static JobGraph createBatchGraph(
+            final List<SlotSharingGroup> slotSharingGroups, List<Integer> parallelisms) {
+        return createJobGraph(slotSharingGroups, parallelisms, ResultPartitionType.BLOCKING);
+    }
+
+    private static JobGraph createJobGraph(
+            final List<SlotSharingGroup> slotSharingGroups,
+            List<Integer> parallelisms,
+            ResultPartitionType resultPartitionType) {
+
+        assertThat(slotSharingGroups.size(), is(3));
+        assertThat(parallelisms.size(), is(3));
 
         JobVertex source = new JobVertex("source");
         source.setInvokableClass(NoOpInvokable.class);
-        source.setParallelism(4);
+        trySetParallelism(source, parallelisms.get(0));
         source.setSlotSharingGroup(slotSharingGroups.get(0));
 
         JobVertex map = new JobVertex("map");
         map.setInvokableClass(NoOpInvokable.class);
-        map.setParallelism(5);
-        map.setSlotSharingGroup(slotSharingGroups.get(0));
+        trySetParallelism(map, parallelisms.get(1));
+        map.setSlotSharingGroup(slotSharingGroups.get(1));
 
         JobVertex sink = new JobVertex("sink");
         sink.setInvokableClass(NoOpInvokable.class);
-        sink.setParallelism(6);
-        sink.setSlotSharingGroup(slotSharingGroups.get(1));
+        trySetParallelism(sink, parallelisms.get(2));
+        sink.setSlotSharingGroup(slotSharingGroups.get(2));
+
+        map.connectNewDataSetAsInput(source, DistributionPattern.POINTWISE, resultPartitionType);
+        sink.connectNewDataSetAsInput(map, DistributionPattern.ALL_TO_ALL, resultPartitionType);
+
+        if (resultPartitionType.isPipelined()) {
+            return JobGraphTestUtils.streamingJobGraph(source, map, sink);
 
-        map.connectNewDataSetAsInput(
-                source, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);
-        sink.connectNewDataSetAsInput(
-                map, DistributionPattern.ALL_TO_ALL, ResultPartitionType.PIPELINED);
+        } else {
+            return JobGraphTestUtils.batchJobGraph(source, map, sink);
+        }
+    }
 
-        return JobGraphTestUtils.streamingJobGraph(source, map, sink);
+    private static void trySetParallelism(JobVertex jobVertex, int parallelism) {
+        if (parallelism > 0) {
+            jobVertex.setParallelism(parallelism);
+        }
     }
 
     private static class TestShuffleMaster implements ShuffleMaster<ShuffleDescriptor> {