Posted to commits@flink.apache.org by xt...@apache.org on 2023/01/17 02:18:24 UTC

[flink] branch master updated (d8417565d6f -> ae8de97ef2a)

This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a change to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


    from d8417565d6f [FLINK-28526][python] Fix Python UDF to support time indicator inputs
     new fc7defb14d1 [FLINK-30471][network] Optimize the enriching network memory process in SsgNetworkMemoryCalculationUtils
     new afdf4a73e43 [FLINK-30472][network] Modify the default value of the max network memory config option
     new ae8de97ef2a [FLINK-30473][network] Optimize the InputGate network memory management for TaskManager

The 3 revisions listed above as "new" are entirely new to this
repository and will be described in separate emails.  The revisions
listed as "add" were already present in the repository and have only
been added to this reference.


Summary of changes:
 .../generated/all_taskmanager_network_section.html |   6 +
 .../generated/common_memory_section.html           |   4 +-
 .../netty_shuffle_environment_configuration.html   |   6 +
 .../task_manager_memory_configuration.html         |   4 +-
 .../NettyShuffleEnvironmentOptions.java            |  29 +++
 .../flink/configuration/TaskManagerOptions.java    |   6 +-
 .../partition/consumer/GateBuffersSpec.java        |  61 +++++++
 .../partition/consumer/InputGateSpecUitls.java     | 125 +++++++++++++
 .../partition/consumer/SingleInputGateFactory.java |  45 +++--
 .../SsgNetworkMemoryCalculationUtils.java          |  99 +++++-----
 .../flink/runtime/shuffle/NettyShuffleMaster.java  |  35 +++-
 .../flink/runtime/shuffle/NettyShuffleUtils.java   |  52 ++++--
 .../shuffle/TaskInputsOutputsDescriptor.java       |  29 ++-
 .../NettyShuffleEnvironmentConfiguration.java      |  29 +++
 .../io/network/NettyShuffleEnvironmentBuilder.java |  10 ++
 .../partition/consumer/GateBuffersSpecTest.java    | 200 +++++++++++++++++++++
 .../partition/consumer/SingleInputGateBuilder.java |   4 +-
 .../partition/consumer/SingleInputGateTest.java    |  63 +++++++
 .../SsgNetworkMemoryCalculationUtilsTest.java      |  13 +-
 .../runtime/shuffle/NettyShuffleUtilsTest.java     |  16 +-
 .../benchmark/SingleInputGateBenchmarkFactory.java |   3 +-
 21 files changed, 746 insertions(+), 93 deletions(-)
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpec.java
 create mode 100644 flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateSpecUitls.java
 create mode 100644 flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpecTest.java


[flink] 03/03: [FLINK-30473][network] Optimize the InputGate network memory management for TaskManager

Posted by xt...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit ae8de97ef2acc798dae34ad2096ece8886bcf308
Author: Yuxin Tan <ta...@gmail.com>
AuthorDate: Mon Dec 12 17:15:41 2022 +0800

    [FLINK-30473][network] Optimize the InputGate network memory management for TaskManager
    
    This closes #21620
---
 .../generated/all_taskmanager_network_section.html |   6 +
 .../netty_shuffle_environment_configuration.html   |   6 +
 .../NettyShuffleEnvironmentOptions.java            |  29 +++
 .../partition/consumer/GateBuffersSpec.java        |  61 +++++++
 .../partition/consumer/InputGateSpecUitls.java     | 125 +++++++++++++
 .../partition/consumer/SingleInputGateFactory.java |  45 +++--
 .../SsgNetworkMemoryCalculationUtils.java          |  46 +++--
 .../flink/runtime/shuffle/NettyShuffleMaster.java  |  35 +++-
 .../flink/runtime/shuffle/NettyShuffleUtils.java   |  52 ++++--
 .../shuffle/TaskInputsOutputsDescriptor.java       |  29 ++-
 .../NettyShuffleEnvironmentConfiguration.java      |  29 +++
 .../io/network/NettyShuffleEnvironmentBuilder.java |  10 ++
 .../partition/consumer/GateBuffersSpecTest.java    | 200 +++++++++++++++++++++
 .../partition/consumer/SingleInputGateBuilder.java |   4 +-
 .../partition/consumer/SingleInputGateTest.java    |  63 +++++++
 .../SsgNetworkMemoryCalculationUtilsTest.java      |   9 +-
 .../runtime/shuffle/NettyShuffleUtilsTest.java     |  16 +-
 .../benchmark/SingleInputGateBenchmarkFactory.java |   3 +-
 18 files changed, 716 insertions(+), 52 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
index d4d5b2bed0a..0b39299491b 100644
--- a/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
+++ b/docs/layouts/shortcodes/generated/all_taskmanager_network_section.html
@@ -104,6 +104,12 @@
             <td>Integer</td>
             <td>Number of max overdraft network buffers to use for each ResultPartition. The overdraft buffers will be used when the subtask cannot apply to the normal buffers  due to back pressure, while subtask is performing an action that can not be interrupted in the middle,  like serializing a large record, flatMap operator producing multiple records for one single input record or processing time timer producing large output. In situations like that system will allow subtask to requ [...]
         </tr>
+        <tr>
+            <td><h5>taskmanager.network.memory.read-buffer.required-per-gate.max</h5></td>
+            <td style="word-wrap: break-word;">(none)</td>
+            <td>Integer</td>
+            <td>The maximum number of network read buffers that are required by an input gate. (An input gate is responsible for reading data from all subtasks of an upstream task.) The number of buffers needed by an input gate is dynamically calculated in runtime, depending on various factors (e.g., the parallelism of the upstream task). Among the calculated number of needed buffers, the part below this configured value is required, while the excess part, if any, is optional. A task wil [...]
+        </tr>
         <tr>
             <td><h5>taskmanager.network.netty.client.connectTimeoutSec</h5></td>
             <td style="word-wrap: break-word;">120</td>
diff --git a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
index adcdef70d27..4d1cd4de409 100644
--- a/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
+++ b/docs/layouts/shortcodes/generated/netty_shuffle_environment_configuration.html
@@ -92,6 +92,12 @@
             <td>Integer</td>
             <td>Number of max overdraft network buffers to use for each ResultPartition. The overdraft buffers will be used when the subtask cannot apply to the normal buffers  due to back pressure, while subtask is performing an action that can not be interrupted in the middle,  like serializing a large record, flatMap operator producing multiple records for one single input record or processing time timer producing large output. In situations like that system will allow subtask to requ [...]
         </tr>
+        <tr>
+            <td><h5>taskmanager.network.memory.read-buffer.required-per-gate.max</h5></td>
+            <td style="word-wrap: break-word;">(none)</td>
+            <td>Integer</td>
+            <td>The maximum number of network read buffers that are required by an input gate. (An input gate is responsible for reading data from all subtasks of an upstream task.) The number of buffers needed by an input gate is dynamically calculated in runtime, depending on various factors (e.g., the parallelism of the upstream task). Among the calculated number of needed buffers, the part below this configured value is required, while the excess part, if any, is optional. A task wil [...]
+        </tr>
         <tr>
             <td><h5>taskmanager.network.netty.client.connectTimeoutSec</h5></td>
             <td style="word-wrap: break-word;">120</td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
index 5e943363d52..b4812e18c02 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/NettyShuffleEnvironmentOptions.java
@@ -225,6 +225,35 @@ public class NettyShuffleEnvironmentOptions {
                                     + " help relieve back-pressure caused by unbalanced data distribution among the subpartitions. This value should be"
                                     + " increased in case of higher round trip times between nodes and/or larger number of machines in the cluster.");
 
+    /**
+     * Maximum number of network buffers to use for each outgoing/incoming gate (result
+     * partition/input gate), which contains all exclusive network buffers for all subpartitions and
+     * all floating buffers for the gate. The exclusive network buffers for one channel are
+     * configured by {@link #NETWORK_BUFFERS_PER_CHANNEL} and the floating buffers for one gate are
+     * configured by {@link #NETWORK_EXTRA_BUFFERS_PER_GATE}.
+     */
+    @Experimental
+    @Documentation.Section(Documentation.Sections.ALL_TASK_MANAGER_NETWORK)
+    public static final ConfigOption<Integer> NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE =
+            key("taskmanager.network.memory.read-buffer.required-per-gate.max")
+                    .intType()
+                    .noDefaultValue()
+                    .withDescription(
+                            "The maximum number of network read buffers that are required by an"
+                                    + " input gate. (An input gate is responsible for reading data"
+                                    + " from all subtasks of an upstream task.) The number of buffers"
+                                    + " needed by an input gate is dynamically calculated in runtime,"
+                                    + " depending on various factors (e.g., the parallelism of the"
+                                    + " upstream task). Among the calculated number of needed buffers,"
+                                    + " the part below this configured value is required, while the"
+                                    + " excess part, if any, is optional. A task will fail if the"
+                                    + " required buffers cannot be obtained in runtime. A task will"
+                                    + " not fail due to not obtaining optional buffers, but may"
+                                    + " suffer a performance reduction. If not explicitly configured,"
+                                    + " the default value is Integer.MAX_VALUE for streaming workloads,"
+                                    + " and 1000 for batch workloads. If explicitly configured, the"
+                                    + " configured value should be at least 1.");
+
     /**
      * Minimum number of network buffers required per blocking result partition for sort-shuffle.
      */
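
Note on usage: the new option can be set like any other Flink configuration value, either in
flink-conf.yaml under the key taskmanager.network.memory.read-buffer.required-per-gate.max or
programmatically. The following is a minimal, illustrative snippet (not part of this patch); the
class name and the value 1000 are only examples, the latter happening to match the documented
batch-workload default.

import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;

public class ReadBufferCapConfigExample {
    public static void main(String[] args) {
        Configuration conf = new Configuration();

        // Cap the *required* read buffers of every input gate at 1000 (illustrative value).
        conf.set(NettyShuffleEnvironmentOptions.NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE, 1000);

        // When the option is not set, getOptional(...) yields Optional.empty() and the runtime
        // falls back to Integer.MAX_VALUE (streaming) or 1000 (batch), as described above.
        System.out.println(
                conf.getOptional(
                        NettyShuffleEnvironmentOptions.NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE));
    }
}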
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpec.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpec.java
new file mode 100644
index 00000000000..0579414c7c6
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpec.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.consumer;
+
+/**
+ * The buffer specs of the {@link InputGate} include the exclusive buffers per channel, the
+ * required/total floating buffers and the target total number of buffers.
+ */
+public class GateBuffersSpec {
+
+    private final int effectiveExclusiveBuffersPerChannel;
+
+    private final int requiredFloatingBuffers;
+
+    private final int totalFloatingBuffers;
+
+    private final int targetTotalBuffersPerGate;
+
+    GateBuffersSpec(
+            int effectiveExclusiveBuffersPerChannel,
+            int requiredFloatingBuffers,
+            int totalFloatingBuffers,
+            int targetTotalBuffersPerGate) {
+        this.effectiveExclusiveBuffersPerChannel = effectiveExclusiveBuffersPerChannel;
+        this.requiredFloatingBuffers = requiredFloatingBuffers;
+        this.totalFloatingBuffers = totalFloatingBuffers;
+        this.targetTotalBuffersPerGate = targetTotalBuffersPerGate;
+    }
+
+    int getRequiredFloatingBuffers() {
+        return requiredFloatingBuffers;
+    }
+
+    int getTotalFloatingBuffers() {
+        return totalFloatingBuffers;
+    }
+
+    int getEffectiveExclusiveBuffersPerChannel() {
+        return effectiveExclusiveBuffersPerChannel;
+    }
+
+    public int targetTotalBuffersPerGate() {
+        return targetTotalBuffersPerGate;
+    }
+}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateSpecUitls.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateSpecUitls.java
new file mode 100644
index 00000000000..55fe71d9234
--- /dev/null
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/InputGateSpecUitls.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.consumer;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+
+import java.util.Optional;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.util.Preconditions.checkState;
+
+/** Utils to manage the specs of the {@link InputGate}, for example, {@link GateBuffersSpec}. */
+public class InputGateSpecUitls {
+
+    public static final int DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_BATCH = 1000;
+
+    public static final int DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_STREAM = Integer.MAX_VALUE;
+
+    public static GateBuffersSpec createGateBuffersSpec(
+            Optional<Integer> configuredMaxRequiredBuffersPerGate,
+            int configuredNetworkBuffersPerChannel,
+            int configuredFloatingNetworkBuffersPerGate,
+            ResultPartitionType partitionType,
+            int numInputChannels) {
+        int maxRequiredBuffersThresholdPerGate =
+                getEffectiveMaxRequiredBuffersPerGate(
+                        partitionType, configuredMaxRequiredBuffersPerGate);
+        int targetRequiredBuffersPerGate =
+                getRequiredBuffersTargetPerGate(
+                        numInputChannels, configuredNetworkBuffersPerChannel);
+        int targetTotalBuffersPerGate =
+                getTotalBuffersTargetPerGate(
+                        numInputChannels,
+                        configuredNetworkBuffersPerChannel,
+                        configuredFloatingNetworkBuffersPerGate);
+        int requiredBuffersPerGate =
+                Math.min(maxRequiredBuffersThresholdPerGate, targetRequiredBuffersPerGate);
+
+        int effectiveExclusiveBuffersPerChannel =
+                getExclusiveBuffersPerChannel(
+                        configuredNetworkBuffersPerChannel,
+                        numInputChannels,
+                        requiredBuffersPerGate);
+        int effectiveExclusiveBuffersPerGate =
+                getEffectiveExclusiveBuffersPerGate(
+                        numInputChannels, effectiveExclusiveBuffersPerChannel);
+
+        int requiredFloatingBuffers = requiredBuffersPerGate - effectiveExclusiveBuffersPerGate;
+        int totalFloatingBuffers = targetTotalBuffersPerGate - effectiveExclusiveBuffersPerGate;
+
+        checkState(requiredFloatingBuffers > 0, "Must be positive.");
+        checkState(
+                requiredFloatingBuffers <= totalFloatingBuffers,
+                "Wrong number of floating buffers.");
+
+        return new GateBuffersSpec(
+                effectiveExclusiveBuffersPerChannel,
+                requiredFloatingBuffers,
+                totalFloatingBuffers,
+                targetTotalBuffersPerGate);
+    }
+
+    @VisibleForTesting
+    static int getEffectiveMaxRequiredBuffersPerGate(
+            ResultPartitionType partitionType,
+            Optional<Integer> configuredMaxRequiredBuffersPerGate) {
+        return configuredMaxRequiredBuffersPerGate.orElseGet(
+                () ->
+                        partitionType.isPipelinedOrPipelinedBoundedResultPartition()
+                                ? DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_STREAM
+                                : DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_BATCH);
+    }
+
+    /**
+     * Since at least one floating buffer is required, the number of required buffers is reduced by
+     * 1 before the average number of exclusive buffers per channel is calculated. The minimum of
+     * this average and the configured buffers per channel is returned, so that the number of
+     * required buffers per gate does not exceed the given requiredBuffersPerGate.
+     */
+    private static int getExclusiveBuffersPerChannel(
+            int configuredNetworkBuffersPerChannel,
+            int numInputChannels,
+            int requiredBuffersPerGate) {
+        checkArgument(numInputChannels > 0, "Must be positive.");
+        checkArgument(requiredBuffersPerGate >= 1, "Require at least 1 buffer per gate.");
+        return Math.min(
+                configuredNetworkBuffersPerChannel,
+                (requiredBuffersPerGate - 1) / numInputChannels);
+    }
+
+    private static int getRequiredBuffersTargetPerGate(
+            int numInputChannels, int configuredNetworkBuffersPerChannel) {
+        return numInputChannels * configuredNetworkBuffersPerChannel + 1;
+    }
+
+    private static int getTotalBuffersTargetPerGate(
+            int numInputChannels,
+            int configuredNetworkBuffersPerChannel,
+            int configuredFloatingBuffersPerGate) {
+        return numInputChannels * configuredNetworkBuffersPerChannel
+                + configuredFloatingBuffersPerGate;
+    }
+
+    private static int getEffectiveExclusiveBuffersPerGate(
+            int numInputChannels, int effectiveExclusiveBuffersPerChannel) {
+        return effectiveExclusiveBuffersPerChannel * numInputChannels;
+    }
+}
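
To make the calculation above concrete, the following standalone sketch (not part of this patch)
mirrors the arithmetic of createGateBuffersSpec for a gate with 500 input channels, 2 exclusive
buffers per channel and 8 floating buffers per gate (the defaults of
taskmanager.network.memory.buffers-per-channel and taskmanager.network.memory.floating-buffers-per-gate),
comparing the streaming cap (Integer.MAX_VALUE) with the batch cap (1000). The class name and
input values are illustrative only.

public class GateBuffersSpecArithmeticSketch {

    public static void main(String[] args) {
        print("streaming", Integer.MAX_VALUE); // default cap for pipelined results
        print("batch", 1000);                  // default cap for blocking results
    }

    private static void print(String mode, int maxRequiredBuffersPerGate) {
        int numInputChannels = 500;     // illustrative upstream parallelism
        int buffersPerChannel = 2;      // buffers-per-channel default
        int floatingBuffersPerGate = 8; // floating-buffers-per-gate default

        int targetRequired = numInputChannels * buffersPerChannel + 1;                   // 1001
        int targetTotal = numInputChannels * buffersPerChannel + floatingBuffersPerGate; // 1008
        int requiredPerGate = Math.min(maxRequiredBuffersPerGate, targetRequired);

        // Reserve one floating buffer, spread the rest across channels as exclusive buffers.
        int exclusivePerChannel =
                Math.min(buffersPerChannel, (requiredPerGate - 1) / numInputChannels);
        int exclusivePerGate = exclusivePerChannel * numInputChannels;

        int requiredFloating = requiredPerGate - exclusivePerGate;
        int totalFloating = targetTotal - exclusivePerGate;

        // streaming: exclusive/channel=2, requiredFloating=1,   totalFloating=8,   total=1008
        // batch:     exclusive/channel=1, requiredFloating=500, totalFloating=508, total=1008
        System.out.printf(
                "%s: exclusive/channel=%d requiredFloating=%d totalFloating=%d total=%d%n",
                mode, exclusivePerChannel, requiredFloating, totalFloating, targetTotal);
    }
}

With even larger parallelism the required part stays capped at 1000 for batch workloads while
everything beyond it becomes optional floating buffers (see GateBuffersSpecTest below).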
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java
index dd02b7a6498..04378c99a40 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateFactory.java
@@ -54,7 +54,9 @@ import org.slf4j.LoggerFactory;
 import javax.annotation.Nonnull;
 
 import java.io.IOException;
+import java.util.Optional;
 
+import static org.apache.flink.runtime.io.network.partition.consumer.InputGateSpecUitls.createGateBuffersSpec;
 import static org.apache.flink.runtime.shuffle.ShuffleUtils.applyWithShuffleTypeCheck;
 
 /** Factory for {@link SingleInputGate} to use in {@link NettyShuffleEnvironment}. */
@@ -75,7 +77,9 @@ public class SingleInputGateFactory {
 
     @Nonnull protected final NetworkBufferPool networkBufferPool;
 
-    protected final int networkBuffersPerChannel;
+    private final Optional<Integer> maxRequiredBuffersPerGate;
+
+    protected final int configuredNetworkBuffersPerChannel;
 
     private final int floatingNetworkBuffersPerGate;
 
@@ -97,7 +101,8 @@ public class SingleInputGateFactory {
         this.taskExecutorResourceId = taskExecutorResourceId;
         this.partitionRequestInitialBackoff = networkConfig.partitionRequestInitialBackoff();
         this.partitionRequestMaxBackoff = networkConfig.partitionRequestMaxBackoff();
-        this.networkBuffersPerChannel =
+        this.maxRequiredBuffersPerGate = networkConfig.maxRequiredBuffersPerGate();
+        this.configuredNetworkBuffersPerChannel =
                 NettyShuffleUtils.getNetworkBuffersPerInputChannel(
                         networkConfig.networkBuffersPerChannel());
         this.floatingNetworkBuffersPerGate = networkConfig.floatingNetworkBuffersPerGate();
@@ -118,8 +123,20 @@ public class SingleInputGateFactory {
             @Nonnull InputGateDeploymentDescriptor igdd,
             @Nonnull PartitionProducerStateProvider partitionProducerStateProvider,
             @Nonnull InputChannelMetrics metrics) {
+        GateBuffersSpec gateBuffersSpec =
+                createGateBuffersSpec(
+                        maxRequiredBuffersPerGate,
+                        configuredNetworkBuffersPerChannel,
+                        floatingNetworkBuffersPerGate,
+                        igdd.getConsumedPartitionType(),
+                        calculateNumChannels(
+                                igdd.getShuffleDescriptors().length,
+                                igdd.getConsumedSubpartitionIndexRange()));
         SupplierWithException<BufferPool, IOException> bufferPoolFactory =
-                createBufferPoolFactory(networkBufferPool, floatingNetworkBuffersPerGate);
+                createBufferPoolFactory(
+                        networkBufferPool,
+                        gateBuffersSpec.getRequiredFloatingBuffers(),
+                        gateBuffersSpec.getTotalFloatingBuffers());
 
         BufferDecompressor bufferDecompressor = null;
         if (igdd.getConsumedPartitionType().supportCompression()
@@ -149,7 +166,8 @@ public class SingleInputGateFactory {
                         maybeCreateBufferDebloater(
                                 owningTaskName, gateIndex, networkInputGroup.addGroup(gateIndex)));
 
-        createInputChannels(owningTaskName, igdd, inputGate, subpartitionIndexRange, metrics);
+        createInputChannels(
+                owningTaskName, igdd, inputGate, subpartitionIndexRange, gateBuffersSpec, metrics);
         return inputGate;
     }
 
@@ -180,6 +198,7 @@ public class SingleInputGateFactory {
             InputGateDeploymentDescriptor inputGateDeploymentDescriptor,
             SingleInputGate inputGate,
             IndexRange subpartitionIndexRange,
+            GateBuffersSpec gateBuffersSpec,
             InputChannelMetrics metrics) {
         ShuffleDescriptor[] shuffleDescriptors =
                 inputGateDeploymentDescriptor.getShuffleDescriptors();
@@ -200,6 +219,7 @@ public class SingleInputGateFactory {
                         createInputChannel(
                                 inputGate,
                                 channelIdx,
+                                gateBuffersSpec.getEffectiveExclusiveBuffersPerChannel(),
                                 shuffleDescriptors[i],
                                 subpartitionIndex,
                                 channelStatistics,
@@ -220,6 +240,7 @@ public class SingleInputGateFactory {
     private InputChannel createInputChannel(
             SingleInputGate inputGate,
             int index,
+            int buffersPerChannel,
             ShuffleDescriptor shuffleDescriptor,
             int consumedSubpartitionIndex,
             ChannelStatistics channelStatistics,
@@ -239,13 +260,14 @@ public class SingleInputGateFactory {
                             connectionManager,
                             partitionRequestInitialBackoff,
                             partitionRequestMaxBackoff,
-                            networkBuffersPerChannel,
+                            buffersPerChannel,
                             metrics);
                 },
                 nettyShuffleDescriptor ->
                         createKnownInputChannel(
                                 inputGate,
                                 index,
+                                buffersPerChannel,
                                 nettyShuffleDescriptor,
                                 consumedSubpartitionIndex,
                                 channelStatistics,
@@ -262,6 +284,7 @@ public class SingleInputGateFactory {
     protected InputChannel createKnownInputChannel(
             SingleInputGate inputGate,
             int index,
+            int buffersPerChannel,
             NettyShuffleDescriptor inputChannelDescriptor,
             int consumedSubpartitionIndex,
             ChannelStatistics channelStatistics,
@@ -279,7 +302,7 @@ public class SingleInputGateFactory {
                     taskEventPublisher,
                     partitionRequestInitialBackoff,
                     partitionRequestMaxBackoff,
-                    networkBuffersPerChannel,
+                    buffersPerChannel,
                     metrics);
         } else {
             // Different instances => remote
@@ -293,17 +316,17 @@ public class SingleInputGateFactory {
                     connectionManager,
                     partitionRequestInitialBackoff,
                     partitionRequestMaxBackoff,
-                    networkBuffersPerChannel,
+                    buffersPerChannel,
                     metrics);
         }
     }
 
     @VisibleForTesting
     static SupplierWithException<BufferPool, IOException> createBufferPoolFactory(
-            BufferPoolFactory bufferPoolFactory, int floatingNetworkBuffersPerGate) {
-        Pair<Integer, Integer> pair =
-                NettyShuffleUtils.getMinMaxFloatingBuffersPerInputGate(
-                        floatingNetworkBuffersPerGate);
+            BufferPoolFactory bufferPoolFactory,
+            int minFloatingBuffersPerGate,
+            int maxFloatingBuffersPerGate) {
+        Pair<Integer, Integer> pair = Pair.of(minFloatingBuffersPerGate, maxFloatingBuffersPerGate);
         return () -> bufferPoolFactory.createBufferPool(pair.getLeft(), pair.getRight());
     }
 
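The practical effect of the factory change is that the gate's local buffer pool is now sized from
the GateBuffersSpec (min = required floating buffers, max = total floating buffers) instead of the
previous fixed range derived from floating-buffers-per-gate. Below is a hedged, illustrative
sketch (not part of this patch) using the existing NetworkBufferPool API and the batch numbers
from the sketch above; pool sizing values are examples.

import org.apache.flink.runtime.io.network.buffer.BufferPool;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;

public class GateBufferPoolSketch {
    public static void main(String[] args) throws Exception {
        // A global pool of 2048 segments of 32 KiB each (illustrative sizing).
        NetworkBufferPool globalPool = new NetworkBufferPool(2048, 32 * 1024);

        // With the batch numbers from the sketch above, the gate's local pool would now be
        // created with min = 500 required floating buffers and max = 508 total floating buffers.
        BufferPool gatePool = globalPool.createBufferPool(500, 508);
        System.out.println(gatePool.getNumberOfRequiredMemorySegments());

        gatePool.lazyDestroy();
        globalPool.destroy();
    }
}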
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index c1e58745aad..41cf1f56918 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -93,26 +93,44 @@ public class SsgNetworkMemoryCalculationUtils {
     private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
             ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
 
+        Map<IntermediateDataSetID, Integer> partitionReuseCount = getPartitionReuseCount(ejv);
         Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
         Map<IntermediateDataSetID, Integer> maxSubpartitionNums = new HashMap<>();
+        Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes = new HashMap<>();
         Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = new HashMap<>();
 
         if (ejv.getGraph().isDynamic()) {
-            getMaxInputChannelInfoForDynamicGraph(ejv, maxInputChannelNums);
+            getMaxInputChannelInfoForDynamicGraph(ejv, maxInputChannelNums, inputPartitionTypes);
             getMaxSubpartitionInfoForDynamicGraph(ejv, maxSubpartitionNums, partitionTypes);
         } else {
-            getMaxInputChannelInfo(ejv, maxInputChannelNums);
+            getMaxInputChannelInfo(ejv, maxInputChannelNums, inputPartitionTypes);
             getMaxSubpartitionInfo(ejv, maxSubpartitionNums, partitionTypes, ejvs);
         }
 
         JobVertex jv = ejv.getJobVertex();
 
         return TaskInputsOutputsDescriptor.from(
-                jv.getNumberOfInputs(), maxInputChannelNums, maxSubpartitionNums, partitionTypes);
+                jv.getNumberOfInputs(),
+                maxInputChannelNums,
+                partitionReuseCount,
+                maxSubpartitionNums,
+                inputPartitionTypes,
+                partitionTypes);
+    }
+
+    private static Map<IntermediateDataSetID, Integer> getPartitionReuseCount(
+            ExecutionJobVertex ejv) {
+        Map<IntermediateDataSetID, Integer> partitionReuseCount = new HashMap<>();
+        for (IntermediateResult intermediateResult : ejv.getInputs()) {
+            partitionReuseCount.merge(intermediateResult.getId(), 1, Integer::sum);
+        }
+        return partitionReuseCount;
     }
 
     private static void getMaxInputChannelInfo(
-            ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) {
+            ExecutionJobVertex ejv,
+            Map<IntermediateDataSetID, Integer> maxInputChannelNums,
+            Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes) {
 
         List<JobEdge> inputEdges = ejv.getJobVertex().getInputs();
 
@@ -128,7 +146,8 @@ public class SsgNetworkMemoryCalculationUtils {
                             ejv.getParallelism(),
                             consumedResult.getNumberOfAssignedPartitions(),
                             inputEdge.getDistributionPattern());
-            maxInputChannelNums.merge(consumedResult.getId(), maxNum, Integer::sum);
+            maxInputChannelNums.put(consumedResult.getId(), maxNum);
+            inputPartitionTypes.putIfAbsent(consumedResult.getId(), consumedResult.getResultType());
         }
     }
 
@@ -162,11 +181,11 @@ public class SsgNetworkMemoryCalculationUtils {
 
     @VisibleForTesting
     static void getMaxInputChannelInfoForDynamicGraph(
-            ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) {
+            ExecutionJobVertex ejv,
+            Map<IntermediateDataSetID, Integer> maxInputChannelNums,
+            Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes) {
 
         for (ExecutionVertex vertex : ejv.getTaskVertices()) {
-            Map<IntermediateDataSetID, Integer> tmp = new HashMap<>();
-
             for (ConsumedPartitionGroup partitionGroup : vertex.getAllConsumedPartitionGroups()) {
 
                 IntermediateResultPartition resultPartition =
@@ -176,14 +195,13 @@ public class SsgNetworkMemoryCalculationUtils {
                                         resultPartition.getIntermediateResult().getId())
                                 .getSubpartitionIndexRange();
 
-                tmp.merge(
+                maxInputChannelNums.merge(
                         partitionGroup.getIntermediateDataSetID(),
                         subpartitionIndexRange.size() * partitionGroup.size(),
-                        Integer::sum);
-            }
-
-            for (Map.Entry<IntermediateDataSetID, Integer> entry : tmp.entrySet()) {
-                maxInputChannelNums.merge(entry.getKey(), entry.getValue(), Integer::max);
+                        Integer::max);
+                inputPartitionTypes.putIfAbsent(
+                        partitionGroup.getIntermediateDataSetID(),
+                        partitionGroup.getResultPartitionType());
             }
         }
     }
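
For readers unfamiliar with the reuse counting introduced here: when one vertex consumes the same
intermediate dataset more than once, each consumption needs its own gate's worth of buffers, so
getPartitionReuseCount counts consumptions per dataset. The following tiny, self-contained sketch
(not part of this patch) shows the same Map.merge pattern; plain strings stand in for
IntermediateDataSetID.

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PartitionReuseCountSketch {
    public static void main(String[] args) {
        // Suppose a vertex consumes dataset "A" twice (the same upstream result feeding two
        // inputs of the vertex) and dataset "B" once.
        List<String> consumedResults = Arrays.asList("A", "A", "B");

        Map<String, Integer> partitionReuseCount = new HashMap<>();
        for (String resultId : consumedResults) {
            // Same merge pattern as getPartitionReuseCount above: count each consumption.
            partitionReuseCount.merge(resultId, 1, Integer::sum);
        }

        // Prints {A=2, B=1}: the per-gate buffer requirement of "A" is later multiplied by 2
        // in NettyShuffleUtils.computeNetworkBuffersForAnnouncing.
        System.out.println(partitionReuseCount);
    }
}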
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
index 59f83c73bf9..7beb385d395 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleMaster.java
@@ -28,8 +28,10 @@ import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor.NetworkPartitionC
 import org.apache.flink.runtime.shuffle.NettyShuffleDescriptor.PartitionConnectionInfo;
 import org.apache.flink.runtime.util.ConfigurationParserUtils;
 
+import java.util.Optional;
 import java.util.concurrent.CompletableFuture;
 
+import static org.apache.flink.util.Preconditions.checkArgument;
 import static org.apache.flink.util.Preconditions.checkNotNull;
 
 /** Default {@link ShuffleMaster} for netty and local file based shuffle implementation. */
@@ -37,7 +39,9 @@ public class NettyShuffleMaster implements ShuffleMaster<NettyShuffleDescriptor>
 
     private final int buffersPerInputChannel;
 
-    private final int buffersPerInputGate;
+    private final int floatingBuffersPerGate;
+
+    private final Optional<Integer> maxRequiredBuffersPerGate;
 
     private final int sortShuffleMinParallelism;
 
@@ -49,14 +53,29 @@ public class NettyShuffleMaster implements ShuffleMaster<NettyShuffleDescriptor>
         checkNotNull(conf);
         buffersPerInputChannel =
                 conf.getInteger(NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_PER_CHANNEL);
-        buffersPerInputGate =
+        floatingBuffersPerGate =
                 conf.getInteger(NettyShuffleEnvironmentOptions.NETWORK_EXTRA_BUFFERS_PER_GATE);
+        maxRequiredBuffersPerGate =
+                conf.getOptional(
+                        NettyShuffleEnvironmentOptions.NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE);
         sortShuffleMinParallelism =
                 conf.getInteger(
                         NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_PARALLELISM);
         sortShuffleMinBuffers =
                 conf.getInteger(NettyShuffleEnvironmentOptions.NETWORK_SORT_SHUFFLE_MIN_BUFFERS);
         networkBufferSize = ConfigurationParserUtils.getPageSize(conf);
+
+        checkArgument(
+                !maxRequiredBuffersPerGate.isPresent() || maxRequiredBuffersPerGate.get() >= 1,
+                String.format(
+                        "At least one buffer is required for each gate, please increase the value of %s.",
+                        NettyShuffleEnvironmentOptions.NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE
+                                .key()));
+        checkArgument(
+                floatingBuffersPerGate >= 1,
+                String.format(
+                        "The configured floating buffer should be at least 1, please increase the value of %s.",
+                        NettyShuffleEnvironmentOptions.NETWORK_EXTRA_BUFFERS_PER_GATE.key()));
     }
 
     @Override
@@ -104,19 +123,17 @@ public class NettyShuffleMaster implements ShuffleMaster<NettyShuffleDescriptor>
     public MemorySize computeShuffleMemorySizeForTask(TaskInputsOutputsDescriptor desc) {
         checkNotNull(desc);
 
-        int numTotalInputChannels =
-                desc.getInputChannelNums().values().stream().mapToInt(Integer::intValue).sum();
-        int numTotalInputGates = desc.getInputGateNums();
-
         int numRequiredNetworkBuffers =
                 NettyShuffleUtils.computeNetworkBuffersForAnnouncing(
                         buffersPerInputChannel,
-                        buffersPerInputGate,
+                        floatingBuffersPerGate,
+                        maxRequiredBuffersPerGate,
                         sortShuffleMinParallelism,
                         sortShuffleMinBuffers,
-                        numTotalInputChannels,
-                        numTotalInputGates,
+                        desc.getInputChannelNums(),
+                        desc.getPartitionReuseCount(),
                         desc.getSubpartitionNums(),
+                        desc.getInputPartitionTypes(),
                         desc.getPartitionTypes());
 
         return new MemorySize((long) networkBufferSize * numRequiredNetworkBuffers);
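
The announced shuffle memory is simply the announced buffer count multiplied by the network buffer
(page) size. An illustrative calculation follows (not part of this patch): 1008 buffers for the
input side as in the batch sketch above, a hypothetical 1000 buffers for the output side, and
32 KiB buffers.

import org.apache.flink.configuration.MemorySize;

public class ShuffleMemoryAnnouncementSketch {
    public static void main(String[] args) {
        int networkBufferSize = 32 * 1024;    // default network buffer (segment) size
        int buffersForInputGate = 1008;       // batch example from the sketch above
        int buffersForResultPartition = 1000; // hypothetical output-side requirement

        int numRequiredNetworkBuffers = buffersForInputGate + buffersForResultPartition;

        // Same final step as computeShuffleMemorySizeForTask: buffer count times buffer size.
        MemorySize shuffleMemory =
                new MemorySize((long) networkBufferSize * numRequiredNetworkBuffers);
        System.out.println(shuffleMemory); // 2008 buffers * 32 KiB = 62.75 MiB
    }
}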
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleUtils.java
index 9950f86ea4e..947f0a976a5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/NettyShuffleUtils.java
@@ -20,13 +20,17 @@ package org.apache.flink.runtime.shuffle;
 
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.io.network.partition.consumer.GateBuffersSpec;
 import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
 
 import org.apache.commons.lang3.tuple.Pair;
 
 import java.util.Map;
+import java.util.Optional;
 
-import static org.apache.flink.util.Preconditions.checkArgument;
+import static org.apache.flink.runtime.io.network.partition.consumer.InputGateSpecUitls.createGateBuffersSpec;
+import static org.apache.flink.util.Preconditions.checkNotNull;
+import static org.apache.flink.util.Preconditions.checkState;
 
 /**
  * Utils to calculate network memory requirement of a vertex from network configuration and details
@@ -86,25 +90,37 @@ public class NettyShuffleUtils {
     public static int computeNetworkBuffersForAnnouncing(
             final int numBuffersPerChannel,
             final int numFloatingBuffersPerGate,
+            final Optional<Integer> maxRequiredBuffersPerGate,
             final int sortShuffleMinParallelism,
             final int sortShuffleMinBuffers,
-            final int numTotalInputChannels,
-            final int numTotalInputGates,
+            final Map<IntermediateDataSetID, Integer> inputChannelNums,
+            final Map<IntermediateDataSetID, Integer> partitionReuseCount,
             final Map<IntermediateDataSetID, Integer> subpartitionNums,
+            final Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes,
             final Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) {
 
-        // Each input channel will retain N exclusive network buffers, N = numBuffersPerChannel.
-        // Each input gate is guaranteed to have a number of floating buffers.
-        int requirementForInputs =
-                getNetworkBuffersPerInputChannel(numBuffersPerChannel) * numTotalInputChannels
-                        + getMinMaxFloatingBuffersPerInputGate(numFloatingBuffersPerGate).getRight()
-                                * numTotalInputGates;
+        int requirementForInputs = 0;
+        for (IntermediateDataSetID dataSetId : inputChannelNums.keySet()) {
+            int numChannels = inputChannelNums.get(dataSetId);
+            ResultPartitionType inputPartitionType = inputPartitionTypes.get(dataSetId);
+            checkNotNull(inputPartitionType);
+
+            int numSingleGateBuffers =
+                    getNumBuffersToAnnounceForInputGate(
+                            inputPartitionType,
+                            numBuffersPerChannel,
+                            numFloatingBuffersPerGate,
+                            maxRequiredBuffersPerGate,
+                            numChannels);
+            checkState(partitionReuseCount.containsKey(dataSetId));
+            requirementForInputs += numSingleGateBuffers * partitionReuseCount.get(dataSetId);
+        }
 
         int requirementForOutputs = 0;
         for (IntermediateDataSetID dataSetId : subpartitionNums.keySet()) {
             int numSubs = subpartitionNums.get(dataSetId);
-            checkArgument(partitionTypes.containsKey(dataSetId));
             ResultPartitionType partitionType = partitionTypes.get(dataSetId);
+            checkNotNull(partitionType);
 
             requirementForOutputs +=
                     getNumBuffersToAnnounceForResultPartition(
@@ -119,6 +135,22 @@ public class NettyShuffleUtils {
         return requirementForInputs + requirementForOutputs;
     }
 
+    private static int getNumBuffersToAnnounceForInputGate(
+            ResultPartitionType type,
+            int configuredNetworkBuffersPerChannel,
+            int floatingNetworkBuffersPerGate,
+            Optional<Integer> maxRequiredBuffersPerGate,
+            int numInputChannels) {
+        GateBuffersSpec gateBuffersSpec =
+                createGateBuffersSpec(
+                        maxRequiredBuffersPerGate,
+                        configuredNetworkBuffersPerChannel,
+                        floatingNetworkBuffersPerGate,
+                        type,
+                        numInputChannels);
+        return gateBuffersSpec.targetTotalBuffersPerGate();
+    }
+
     private static int getNumBuffersToAnnounceForResultPartition(
             ResultPartitionType type,
             int configuredNetworkBuffersPerChannel,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/TaskInputsOutputsDescriptor.java b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/TaskInputsOutputsDescriptor.java
index 1cd5614e413..ceb11033113 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/TaskInputsOutputsDescriptor.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/shuffle/TaskInputsOutputsDescriptor.java
@@ -35,25 +35,37 @@ public class TaskInputsOutputsDescriptor {
     // Number of input channels per dataSet.
     private final Map<IntermediateDataSetID, Integer> inputChannelNums;
 
+    // Number of times the partitions of each dataSet are consumed (reuse count).
+    private final Map<IntermediateDataSetID, Integer> partitionReuseCount;
+
     // Number of subpartitions per dataSet.
     private final Map<IntermediateDataSetID, Integer> subpartitionNums;
 
+    // Result partition types of input channels.
+    private final Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes;
+
     // ResultPartitionType per dataSet.
     private final Map<IntermediateDataSetID, ResultPartitionType> partitionTypes;
 
     private TaskInputsOutputsDescriptor(
             int inputGateNums,
             Map<IntermediateDataSetID, Integer> inputChannelNums,
+            Map<IntermediateDataSetID, Integer> partitionReuseCount,
             Map<IntermediateDataSetID, Integer> subpartitionNums,
+            Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes,
             Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) {
 
         checkNotNull(inputChannelNums);
+        checkNotNull(partitionReuseCount);
         checkNotNull(subpartitionNums);
+        checkNotNull(inputPartitionTypes);
         checkNotNull(partitionTypes);
 
         this.inputGateNums = inputGateNums;
         this.inputChannelNums = inputChannelNums;
+        this.partitionReuseCount = partitionReuseCount;
         this.subpartitionNums = subpartitionNums;
+        this.inputPartitionTypes = inputPartitionTypes;
         this.partitionTypes = partitionTypes;
     }
 
@@ -65,10 +77,18 @@ public class TaskInputsOutputsDescriptor {
         return Collections.unmodifiableMap(inputChannelNums);
     }
 
+    public Map<IntermediateDataSetID, Integer> getPartitionReuseCount() {
+        return partitionReuseCount;
+    }
+
     public Map<IntermediateDataSetID, Integer> getSubpartitionNums() {
         return Collections.unmodifiableMap(subpartitionNums);
     }
 
+    public Map<IntermediateDataSetID, ResultPartitionType> getInputPartitionTypes() {
+        return Collections.unmodifiableMap(inputPartitionTypes);
+    }
+
     public Map<IntermediateDataSetID, ResultPartitionType> getPartitionTypes() {
         return Collections.unmodifiableMap(partitionTypes);
     }
@@ -76,10 +96,17 @@ public class TaskInputsOutputsDescriptor {
     public static TaskInputsOutputsDescriptor from(
             int inputGateNums,
             Map<IntermediateDataSetID, Integer> inputChannelNums,
+            Map<IntermediateDataSetID, Integer> partitionReuseCount,
             Map<IntermediateDataSetID, Integer> subpartitionNums,
+            Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes,
             Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) {
 
         return new TaskInputsOutputsDescriptor(
-                inputGateNums, inputChannelNums, subpartitionNums, partitionTypes);
+                inputGateNums,
+                inputChannelNums,
+                partitionReuseCount,
+                subpartitionNums,
+                inputPartitionTypes,
+                partitionTypes);
     }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NettyShuffleEnvironmentConfiguration.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NettyShuffleEnvironmentConfiguration.java
index 2bc9effdfa0..31a0dcab2e8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NettyShuffleEnvironmentConfiguration.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/NettyShuffleEnvironmentConfiguration.java
@@ -41,6 +41,9 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
+import java.util.Optional;
+
+import static org.apache.flink.util.Preconditions.checkArgument;
 
 /** Configuration object for the network stack. */
 public class NettyShuffleEnvironmentConfiguration {
@@ -67,6 +70,8 @@ public class NettyShuffleEnvironmentConfiguration {
      */
     private final int floatingNetworkBuffersPerGate;
 
+    private final Optional<Integer> maxRequiredBuffersPerGate;
+
     private final int sortShuffleMinBuffers;
 
     private final int sortShuffleMinParallelism;
@@ -110,6 +115,7 @@ public class NettyShuffleEnvironmentConfiguration {
             int partitionRequestMaxBackoff,
             int networkBuffersPerChannel,
             int floatingNetworkBuffersPerGate,
+            Optional<Integer> maxRequiredBuffersPerGate,
             Duration requestSegmentsTimeout,
             boolean isNetworkDetailedMetrics,
             @Nullable NettyConfig nettyConfig,
@@ -134,6 +140,7 @@ public class NettyShuffleEnvironmentConfiguration {
         this.partitionRequestMaxBackoff = partitionRequestMaxBackoff;
         this.networkBuffersPerChannel = networkBuffersPerChannel;
         this.floatingNetworkBuffersPerGate = floatingNetworkBuffersPerGate;
+        this.maxRequiredBuffersPerGate = maxRequiredBuffersPerGate;
         this.requestSegmentsTimeout = Preconditions.checkNotNull(requestSegmentsTimeout);
         this.isNetworkDetailedMetrics = isNetworkDetailedMetrics;
         this.nettyConfig = nettyConfig;
@@ -180,6 +187,10 @@ public class NettyShuffleEnvironmentConfiguration {
         return floatingNetworkBuffersPerGate;
     }
 
+    public Optional<Integer> maxRequiredBuffersPerGate() {
+        return maxRequiredBuffersPerGate;
+    }
+
     public long batchShuffleReadMemoryBytes() {
         return batchShuffleReadMemoryBytes;
     }
@@ -299,6 +310,10 @@ public class NettyShuffleEnvironmentConfiguration {
                 configuration.getInteger(
                         NettyShuffleEnvironmentOptions.NETWORK_EXTRA_BUFFERS_PER_GATE);
 
+        Optional<Integer> maxRequiredBuffersPerGate =
+                configuration.getOptional(
+                        NettyShuffleEnvironmentOptions.NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE);
+
         int maxBuffersPerChannel =
                 configuration.getInteger(
                         NettyShuffleEnvironmentOptions.NETWORK_MAX_BUFFERS_PER_CHANNEL);
@@ -359,6 +374,19 @@ public class NettyShuffleEnvironmentConfiguration {
                         NettyShuffleEnvironmentOptions
                                 .HYBRID_SHUFFLE_NUM_RETAINED_IN_MEMORY_REGIONS_MAX);
 
+        checkArgument(buffersPerChannel >= 0, "Must be non-negative.");
+        checkArgument(
+                !maxRequiredBuffersPerGate.isPresent() || maxRequiredBuffersPerGate.get() >= 1,
+                String.format(
+                        "At least one buffer is required for each gate, please increase the value of %s.",
+                        NettyShuffleEnvironmentOptions.NETWORK_READ_MAX_REQUIRED_BUFFERS_PER_GATE
+                                .key()));
+        checkArgument(
+                extraBuffersPerGate >= 1,
+                String.format(
+                        "The configured floating buffer should be at least 1, please increase the value of %s.",
+                        NettyShuffleEnvironmentOptions.NETWORK_EXTRA_BUFFERS_PER_GATE.key()));
+
         return new NettyShuffleEnvironmentConfiguration(
                 numberOfNetworkBuffers,
                 pageSize,
@@ -366,6 +394,7 @@ public class NettyShuffleEnvironmentConfiguration {
                 maxRequestBackoff,
                 buffersPerChannel,
                 extraBuffersPerGate,
+                maxRequiredBuffersPerGate,
                 requestSegmentsTimeout,
                 isNetworkDetailedMetrics,
                 nettyConfig,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironmentBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironmentBuilder.java
index 2862ab0c4c5..08488ff0ab3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironmentBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/NettyShuffleEnvironmentBuilder.java
@@ -31,6 +31,7 @@ import org.apache.flink.runtime.util.EnvironmentInformation;
 import org.apache.flink.util.concurrent.Executors;
 
 import java.time.Duration;
+import java.util.Optional;
 import java.util.concurrent.Executor;
 
 /** Builder for the {@link NettyShuffleEnvironment}. */
@@ -57,6 +58,8 @@ public class NettyShuffleEnvironmentBuilder {
 
     private int floatingNetworkBuffersPerGate = 8;
 
+    private Optional<Integer> maxRequiredBuffersPerGate = Optional.of(Integer.MAX_VALUE);
+
     private int sortShuffleMinBuffers = 100;
 
     private int sortShuffleMinParallelism = Integer.MAX_VALUE;
@@ -131,6 +134,12 @@ public class NettyShuffleEnvironmentBuilder {
         return this;
     }
 
+    public NettyShuffleEnvironmentBuilder setMaxRequiredBuffersPerGate(
+            Optional<Integer> maxRequiredBuffersPerGate) {
+        this.maxRequiredBuffersPerGate = maxRequiredBuffersPerGate;
+        return this;
+    }
+
     public NettyShuffleEnvironmentBuilder setMaxBuffersPerChannel(int maxBuffersPerChannel) {
         this.maxBuffersPerChannel = maxBuffersPerChannel;
         return this;
@@ -230,6 +239,7 @@ public class NettyShuffleEnvironmentBuilder {
                         partitionRequestMaxBackoff,
                         networkBuffersPerChannel,
                         floatingNetworkBuffersPerGate,
+                        maxRequiredBuffersPerGate,
                         DEFAULT_REQUEST_SEGMENTS_TIMEOUT,
                         false,
                         nettyConfig,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpecTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpecTest.java
new file mode 100644
index 00000000000..8161d00c0d1
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/GateBuffersSpecTest.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition.consumer;
+
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import java.util.Optional;
+
+import static org.apache.flink.runtime.io.network.partition.consumer.InputGateSpecUitls.getEffectiveMaxRequiredBuffersPerGate;
+import static org.assertj.core.api.Assertions.assertThat;
+
+/** Tests for {@link GateBuffersSpec}. */
+class GateBuffersSpecTest {
+
+    private static ResultPartitionType[] parameters() {
+        return ResultPartitionType.values();
+    }
+
+    @ParameterizedTest
+    @MethodSource("parameters")
+    void testCalculationWithSufficientRequiredBuffers(ResultPartitionType partitionType) {
+        int numInputChannels = 499;
+        GateBuffersSpec gateBuffersSpec = createGateBuffersSpec(numInputChannels, partitionType);
+
+        int minFloating = 1;
+        int maxFloating = 8;
+        int numExclusivePerChannel = 2;
+        int targetTotalBuffersPerGate = 1006;
+
+        checkBuffersInGate(
+                gateBuffersSpec,
+                minFloating,
+                maxFloating,
+                numExclusivePerChannel,
+                targetTotalBuffersPerGate);
+    }
+
+    @ParameterizedTest
+    @MethodSource("parameters")
+    void testCalculationWithOneExclusiveBuffer(ResultPartitionType partitionType) {
+        int numInputChannels = 500;
+        GateBuffersSpec gateBuffersSpec = createGateBuffersSpec(numInputChannels, partitionType);
+
+        boolean isPipeline = isPipelineResultPartition(partitionType);
+        int minFloating = isPipeline ? 1 : 500;
+        int maxFloating = isPipelineResultPartition(partitionType) ? 8 : 508;
+        int numExclusivePerChannel = isPipelineResultPartition(partitionType) ? 2 : 1;
+        int targetTotalBuffersPerGate = 1008;
+
+        checkBuffersInGate(
+                gateBuffersSpec,
+                minFloating,
+                maxFloating,
+                numExclusivePerChannel,
+                targetTotalBuffersPerGate);
+    }
+
+    @ParameterizedTest
+    @MethodSource("parameters")
+    void testUpperBoundaryCalculationWithOneExclusiveBuffer(ResultPartitionType partitionType) {
+        int numInputChannels = 999;
+        GateBuffersSpec gateBuffersSpec = createGateBuffersSpec(numInputChannels, partitionType);
+
+        int minFloating = 1;
+        int maxFloating = isPipelineResultPartition(partitionType) ? 8 : 1007;
+        int numExclusivePerChannel = isPipelineResultPartition(partitionType) ? 2 : 1;
+        int targetTotalBuffersPerGate = 2006;
+
+        checkBuffersInGate(
+                gateBuffersSpec,
+                minFloating,
+                maxFloating,
+                numExclusivePerChannel,
+                targetTotalBuffersPerGate);
+    }
+
+    @ParameterizedTest
+    @MethodSource("parameters")
+    void testBoundaryCalculationWithoutExclusiveBuffer(ResultPartitionType partitionType) {
+        int numInputChannels = 1000;
+        GateBuffersSpec gateBuffersSpec = createGateBuffersSpec(numInputChannels, partitionType);
+
+        boolean isPipeline = isPipelineResultPartition(partitionType);
+        int minFloating = isPipeline ? 1 : 1000;
+        int maxFloating = isPipeline ? 8 : numInputChannels * 2 + 8;
+        int numExclusivePerChannel = isPipeline ? 2 : 0;
+        int targetTotalBuffersPerGate = 2008;
+
+        checkBuffersInGate(
+                gateBuffersSpec,
+                minFloating,
+                maxFloating,
+                numExclusivePerChannel,
+                targetTotalBuffersPerGate);
+    }
+
+    @ParameterizedTest
+    @MethodSource("parameters")
+    void testCalculationWithConfiguredZeroExclusiveBuffer(ResultPartitionType partitionType) {
+        int numInputChannels = 1001;
+        int numExclusiveBuffersPerChannel = 0;
+        GateBuffersSpec gateBuffersSpec =
+                createGateBuffersSpec(
+                        numInputChannels, partitionType, numExclusiveBuffersPerChannel);
+
+        int minFloating = 1;
+        int maxFloating = 8;
+        int numExclusivePerChannel = 0;
+        int targetTotalBuffersPerGate = 8;
+
+        checkBuffersInGate(
+                gateBuffersSpec,
+                minFloating,
+                maxFloating,
+                numExclusivePerChannel,
+                targetTotalBuffersPerGate);
+    }
+
+    @ParameterizedTest
+    @MethodSource("parameters")
+    void testConfiguredMaxRequiredBuffersPerGate(ResultPartitionType partitionType) {
+        Optional<Integer> emptyConfig = Optional.empty();
+        int effectiveMaxRequiredBuffers =
+                getEffectiveMaxRequiredBuffersPerGate(partitionType, emptyConfig);
+        int expectEffectiveMaxRequiredBuffers =
+                isPipelineResultPartition(partitionType)
+                        ? InputGateSpecUitls.DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_STREAM
+                        : InputGateSpecUitls.DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_BATCH;
+        assertThat(effectiveMaxRequiredBuffers).isEqualTo(expectEffectiveMaxRequiredBuffers);
+
+        Optional<Integer> configuredMaxRequiredBuffers = Optional.of(100);
+        effectiveMaxRequiredBuffers =
+                getEffectiveMaxRequiredBuffersPerGate(partitionType, configuredMaxRequiredBuffers);
+        assertThat(effectiveMaxRequiredBuffers).isEqualTo(configuredMaxRequiredBuffers.get());
+    }
+
+    private static void checkBuffersInGate(
+            GateBuffersSpec gateBuffersSpec,
+            int minFloating,
+            int maxFloating,
+            int numExclusivePerChannel,
+            int targetTotalBuffersPerGate) {
+        assertThat(gateBuffersSpec.getRequiredFloatingBuffers()).isEqualTo(minFloating);
+        assertThat(gateBuffersSpec.getTotalFloatingBuffers()).isEqualTo(maxFloating);
+        assertThat(gateBuffersSpec.getEffectiveExclusiveBuffersPerChannel())
+                .isEqualTo(numExclusivePerChannel);
+        assertThat(gateBuffersSpec.targetTotalBuffersPerGate())
+                .isEqualTo(targetTotalBuffersPerGate);
+    }
+
+    private static GateBuffersSpec createGateBuffersSpec(
+            int numInputChannels, ResultPartitionType partitionType) {
+        return createGateBuffersSpec(numInputChannels, partitionType, 2);
+    }
+
+    private static GateBuffersSpec createGateBuffersSpec(
+            int numInputChannels,
+            ResultPartitionType partitionType,
+            int numExclusiveBuffersPerChannel) {
+        return InputGateSpecUitls.createGateBuffersSpec(
+                getMaxRequiredBuffersPerGate(partitionType),
+                numExclusiveBuffersPerChannel,
+                8,
+                partitionType,
+                numInputChannels);
+    }
+
+    private static Optional<Integer> getMaxRequiredBuffersPerGate(
+            ResultPartitionType partitionType) {
+        return isPipelineResultPartition(partitionType)
+                ? Optional.of(InputGateSpecUitls.DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_STREAM)
+                : Optional.of(InputGateSpecUitls.DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_BATCH);
+    }
+
+    private static boolean isPipelineResultPartition(ResultPartitionType partitionType) {
+        return partitionType.isPipelinedOrPipelinedBoundedResultPartition();
+    }
+}
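
The expected values above all follow a single pattern: a gate targets numChannels * exclusiveBuffersPerChannel + floatingBuffersPerGate buffers in total, but the buffers it strictly requires are capped per gate (1000 for batch results in these tests, and large enough for pipelined ones that the cap never applies here), and the exclusive buffers per channel are scaled down until the required amount fits under that cap, with the difference shifted to floating buffers. A minimal sketch of that arithmetic, reconstructed only from the expectations in this test (class and method names are illustrative, not taken from the patch):

    /** Sketch only: reproduces the expected values in GateBuffersSpecTest. */
    final class GateBuffersArithmeticSketch {

        /** Returns {effectiveExclusivePerChannel, requiredFloating, totalFloating, targetTotal}. */
        static int[] compute(
                int numChannels,
                int configuredExclusivePerChannel,
                int floatingBuffersPerGate,
                int maxRequiredBuffersPerGate) {
            int targetRequired = numChannels * configuredExclusivePerChannel + 1;
            int targetTotal = numChannels * configuredExclusivePerChannel + floatingBuffersPerGate;
            int requiredPerGate = Math.min(maxRequiredBuffersPerGate, targetRequired);
            // Shrink the exclusive buffers until the strictly required buffers fit under the cap.
            int effectiveExclusive =
                    Math.min(configuredExclusivePerChannel, (requiredPerGate - 1) / numChannels);
            int requiredFloating = requiredPerGate - numChannels * effectiveExclusive;
            int totalFloating = targetTotal - numChannels * effectiveExclusive;
            return new int[] {effectiveExclusive, requiredFloating, totalFloating, targetTotal};
        }
    }

For example, compute(500, 2, 8, 1000) yields {1, 500, 508, 1008}, which is exactly what testCalculationWithOneExclusiveBuffer asserts for a BLOCKING partition, and compute(1000, 2, 8, 1000) yields {0, 1000, 2008, 2008} as in testBoundaryCalculationWithoutExclusiveBuffer.
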
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java
index b8a052f8e74..8eabc1bc8f1 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateBuilder.java
@@ -111,7 +111,9 @@ public class SingleInputGateBuilder {
         NettyShuffleEnvironmentConfiguration config = environment.getConfiguration();
         this.bufferPoolFactory =
                 SingleInputGateFactory.createBufferPoolFactory(
-                        environment.getNetworkBufferPool(), config.floatingNetworkBuffersPerGate());
+                        environment.getNetworkBufferPool(),
+                        1,
+                        config.floatingNetworkBuffersPerGate());
         this.segmentProvider = environment.getNetworkBufferPool();
         return this;
     }
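
The builder now passes an extra argument between the network buffer pool and the per-gate floating buffer count; judging from the surrounding changes it is the minimum number of segments the gate's local buffer pool must hold. A sketch of the shape the updated SingleInputGateFactory.createBufferPoolFactory presumably has (signature and names are assumptions, not copied from this patch):

    import org.apache.flink.runtime.io.network.buffer.BufferPool;
    import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
    import org.apache.flink.util.function.SupplierWithException;

    import java.io.IOException;

    class BufferPoolFactorySketch {
        // Assumed shape: the new middle argument becomes the pool's required minimum,
        // while the last argument stays the pool's maximum number of segments.
        static SupplierWithException<BufferPool, IOException> createBufferPoolFactory(
                BufferPoolFactory bufferPoolFactory,
                int minFloatingBuffersPerGate,
                int maxFloatingBuffersPerGate) {
            return () ->
                    bufferPoolFactory.createBufferPool(
                            minFloatingBuffersPerGate, maxFloatingBuffersPerGate);
        }
    }
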
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
index 6ff3f267268..e591d1e7951 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGateTest.java
@@ -1198,8 +1198,71 @@ public class SingleInputGateTest extends InputGateTestBase {
         assertThat(inputGate.getBuffersInUseCount()).isEqualTo(3);
     }
 
+    @Test
+    void testCalculateInputGateNetworkBuffers() throws Exception {
+        verifyBuffersInBufferPool(true, 2);
+        verifyBuffersInBufferPool(false, 2);
+        verifyBuffersInBufferPool(true, 500);
+        verifyBuffersInBufferPool(false, 500);
+    }
+
     // ---------------------------------------------------------------------------------------------
 
+    private static void verifyBuffersInBufferPool(boolean isPipeline, int subpartitionRangeSize)
+            throws Exception {
+        IntermediateResultPartitionID[] partitionIds =
+                new IntermediateResultPartitionID[] {
+                    new IntermediateResultPartitionID(),
+                    new IntermediateResultPartitionID(),
+                    new IntermediateResultPartitionID()
+                };
+
+        IndexRange subpartitionIndexRange = new IndexRange(0, subpartitionRangeSize - 1);
+        NettyShuffleEnvironmentBuilder nettyShuffleEnvironmentBuilder =
+                new NettyShuffleEnvironmentBuilder();
+        Optional<Integer> expectMaxRequiredBuffersPerGate =
+                isPipeline
+                        ? Optional.of(
+                                InputGateSpecUitls.DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_STREAM)
+                        : Optional.of(
+                                InputGateSpecUitls.DEFAULT_MAX_REQUIRED_BUFFERS_PER_GATE_FOR_BATCH);
+        nettyShuffleEnvironmentBuilder.setMaxRequiredBuffersPerGate(
+                expectMaxRequiredBuffersPerGate);
+        NettyShuffleEnvironment netEnv = nettyShuffleEnvironmentBuilder.build();
+
+        SingleInputGate gate =
+                createSingleInputGate(
+                        partitionIds,
+                        isPipeline ? ResultPartitionType.PIPELINED : ResultPartitionType.BLOCKING,
+                        subpartitionIndexRange,
+                        netEnv,
+                        ResourceID.generate(),
+                        new TestingConnectionManager(),
+                        new TestingResultPartitionManager(new NoOpResultSubpartitionView()));
+        gate.setup();
+
+        for (InputChannel inputChannel : gate.getInputChannels().values()) {
+            if (inputChannel instanceof RemoteInputChannel) {
+                assertThat(((RemoteInputChannel) inputChannel).getInitialCredit()).isEqualTo(0);
+            }
+        }
+
+        int targetTotalBuffersPerGate = 2 * partitionIds.length * subpartitionRangeSize + 8;
+        int requiredFloatingBuffersPerGate;
+        int totalFloatingBuffersPerGate;
+        if (targetTotalBuffersPerGate >= expectMaxRequiredBuffersPerGate.get()) {
+            requiredFloatingBuffersPerGate = expectMaxRequiredBuffersPerGate.get();
+            totalFloatingBuffersPerGate = targetTotalBuffersPerGate;
+        } else {
+            requiredFloatingBuffersPerGate = 1;
+            totalFloatingBuffersPerGate = 8;
+        }
+        assertThat(gate.getBufferPool().getNumberOfRequiredMemorySegments())
+                .isEqualTo(requiredFloatingBuffersPerGate);
+        assertThat(gate.getBufferPool().getMaxNumberOfMemorySegments())
+                .isEqualTo(totalFloatingBuffersPerGate);
+    }
+
     private static SubpartitionInfo createSubpartitionInfo(
             IntermediateResultPartitionID partitionId) {
         return createSubpartitionInfo(partitionId, 0);
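
The pool bounds asserted in verifyBuffersInBufferPool follow the same arithmetic as the sketch after GateBuffersSpecTest above: the gate consumes 3 partitions with subpartitionRangeSize subpartitions each, i.e. 3 * subpartitionRangeSize channels, with 2 exclusive buffers per channel and 8 floating buffers per gate. Worked out for the asserted cases (values assume the defaults used in this test):

    // 3 partitions x 500 subpartitions = 1500 channels, BLOCKING, per-gate cap 1000:
    //   requiredPerGate    = min(1000, 1500 * 2 + 1) = 1000
    //   effectiveExclusive = min(2, (1000 - 1) / 1500) = 0
    //   pool bounds        = [1000, 1500 * 2 + 8] = [1000, 3008]
    // 3 partitions x 2 subpartitions = 6 channels, either type:
    //   requiredPerGate    = min(cap, 6 * 2 + 1) = 13, effectiveExclusive = 2
    //   pool bounds        = [13 - 12, 20 - 12] = [1, 8]
    // 1500 channels, PIPELINED: the cap never applies, so effectiveExclusive stays 2
    //   and the pool bounds remain [1, 8].
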
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
index 45d71b8066f..9046d0e61c2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
@@ -96,7 +96,7 @@ public class SsgNetworkMemoryCalculationUtilsTest {
                 new MemorySize(
                         TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 2)
                                 + TestShuffleMaster.computeRequiredShuffleMemoryBytes(1, 6)),
-                new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(10, 0)));
+                new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 0)));
     }
 
     private void testGenerateEnrichedResourceProfile(
@@ -168,7 +168,7 @@ public class SsgNetworkMemoryCalculationUtilsTest {
                         new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(0, 5)),
                         new MemorySize(TestShuffleMaster.computeRequiredShuffleMemoryBytes(5, 20)),
                         new MemorySize(
-                                TestShuffleMaster.computeRequiredShuffleMemoryBytes(30, 0))));
+                                TestShuffleMaster.computeRequiredShuffleMemoryBytes(15, 0))));
     }
 
     private void triggerComputeNumOfSubpartitions(IntermediateResult result) {
@@ -233,11 +233,14 @@ public class SsgNetworkMemoryCalculationUtilsTest {
         eg.initializeJobVertex(consumer, 0L);
 
         Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
+        Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes = new HashMap<>();
         SsgNetworkMemoryCalculationUtils.getMaxInputChannelInfoForDynamicGraph(
-                consumer, maxInputChannelNums);
+                consumer, maxInputChannelNums, inputPartitionTypes);
 
         assertThat(maxInputChannelNums.size(), is(1));
         assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels));
+        assertThat(inputPartitionTypes.size(), is(1));
+        assertThat(inputPartitionTypes.get(result.getId()), is(result.getResultType()));
     }
 
     private DefaultExecutionGraph createDynamicExecutionGraph(
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/NettyShuffleUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/NettyShuffleUtilsTest.java
index abcba5687c3..1c215d874be 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/NettyShuffleUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/shuffle/NettyShuffleUtilsTest.java
@@ -45,6 +45,7 @@ import java.io.IOException;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Map;
+import java.util.Optional;
 
 import static org.apache.flink.runtime.executiongraph.ExecutionGraphTestUtils.createExecutionAttemptId;
 import static org.apache.flink.runtime.io.network.partition.ResultPartitionType.BLOCKING;
@@ -64,9 +65,13 @@ public class NettyShuffleUtilsTest extends TestLogger {
     public void testComputeRequiredNetworkBuffers() throws Exception {
         int numBuffersPerChannel = 5;
         int numBuffersPerGate = 8;
+        Optional<Integer> maxRequiredBuffersPerGate = Optional.of(Integer.MAX_VALUE);
         int sortShuffleMinParallelism = 8;
         int numSortShuffleMinBuffers = 12;
 
+        IntermediateDataSetID ids1 = new IntermediateDataSetID();
+        IntermediateDataSetID ids2 = new IntermediateDataSetID();
+
         int numChannels1 = 3;
         int numChannels2 = 4;
 
@@ -81,16 +86,23 @@ public class NettyShuffleUtilsTest extends TestLogger {
                 ImmutableMap.of(ds1, numSubs1, ds2, numSubs2, ds3, numSubs3);
         Map<IntermediateDataSetID, ResultPartitionType> partitionTypes =
                 ImmutableMap.of(ds1, PIPELINED_BOUNDED, ds2, BLOCKING, ds3, BLOCKING);
+        Map<IntermediateDataSetID, Integer> numInputChannels =
+                ImmutableMap.of(ids1, numChannels1, ids2, numChannels2);
+        Map<IntermediateDataSetID, Integer> partitionReuseCount = ImmutableMap.of(ids1, 1, ids2, 1);
+        Map<IntermediateDataSetID, ResultPartitionType> inputPartitionTypes =
+                ImmutableMap.of(ids1, PIPELINED_BOUNDED, ids2, BLOCKING);
 
         int numTotalBuffers =
                 NettyShuffleUtils.computeNetworkBuffersForAnnouncing(
                         numBuffersPerChannel,
                         numBuffersPerGate,
+                        maxRequiredBuffersPerGate,
                         sortShuffleMinParallelism,
                         numSortShuffleMinBuffers,
-                        numChannels1 + numChannels2,
-                        2,
+                        numInputChannels,
+                        partitionReuseCount,
                         subpartitionNums,
+                        inputPartitionTypes,
                         partitionTypes);
 
         NettyShuffleEnvironment sEnv =
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java
index 7dda72c3f49..aa10803cbbd 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/io/benchmark/SingleInputGateBenchmarkFactory.java
@@ -63,6 +63,7 @@ public class SingleInputGateBenchmarkFactory extends SingleInputGateFactory {
     protected InputChannel createKnownInputChannel(
             SingleInputGate inputGate,
             int index,
+            int buffersPerChannel,
             NettyShuffleDescriptor inputChannelDescriptor,
             int consumedSubpartitionIndex,
             SingleInputGateFactory.ChannelStatistics channelStatistics,
@@ -89,7 +90,7 @@ public class SingleInputGateBenchmarkFactory extends SingleInputGateFactory {
                     connectionManager,
                     partitionRequestInitialBackoff,
                     partitionRequestMaxBackoff,
-                    networkBuffersPerChannel,
+                    configuredNetworkBuffersPerChannel,
                     metrics);
         }
     }


[flink] 01/03: [FLINK-30471][network] Optimize the enriching network memory process in SsgNetworkMemoryCalculationUtils

Posted by xt...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit fc7defb14d11e270d539ee0d80a2076ae55a4ea2
Author: Yuxin Tan <ta...@gmail.com>
AuthorDate: Mon Dec 12 21:39:20 2022 +0800

    [FLINK-30471][network] Optimize the enriching network memory process in SsgNetworkMemoryCalculationUtils
---
 .../SsgNetworkMemoryCalculationUtils.java          | 67 +++++++++-------------
 .../SsgNetworkMemoryCalculationUtilsTest.java      |  6 +-
 2 files changed, 31 insertions(+), 42 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
index bca093c7219..c1e58745aad 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtils.java
@@ -93,28 +93,27 @@ public class SsgNetworkMemoryCalculationUtils {
     private static TaskInputsOutputsDescriptor buildTaskInputsOutputsDescriptor(
             ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
 
-        Map<IntermediateDataSetID, Integer> maxInputChannelNums;
-        Map<IntermediateDataSetID, Integer> maxSubpartitionNums;
+        Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
+        Map<IntermediateDataSetID, Integer> maxSubpartitionNums = new HashMap<>();
+        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = new HashMap<>();
 
         if (ejv.getGraph().isDynamic()) {
-            maxInputChannelNums = getMaxInputChannelNumsForDynamicGraph(ejv);
-            maxSubpartitionNums = getMaxSubpartitionNumsForDynamicGraph(ejv);
+            getMaxInputChannelInfoForDynamicGraph(ejv, maxInputChannelNums);
+            getMaxSubpartitionInfoForDynamicGraph(ejv, maxSubpartitionNums, partitionTypes);
         } else {
-            maxInputChannelNums = getMaxInputChannelNums(ejv);
-            maxSubpartitionNums = getMaxSubpartitionNums(ejv, ejvs);
+            getMaxInputChannelInfo(ejv, maxInputChannelNums);
+            getMaxSubpartitionInfo(ejv, maxSubpartitionNums, partitionTypes, ejvs);
         }
 
         JobVertex jv = ejv.getJobVertex();
-        Map<IntermediateDataSetID, ResultPartitionType> partitionTypes = getPartitionTypes(jv);
 
         return TaskInputsOutputsDescriptor.from(
                 jv.getNumberOfInputs(), maxInputChannelNums, maxSubpartitionNums, partitionTypes);
     }
 
-    private static Map<IntermediateDataSetID, Integer> getMaxInputChannelNums(
-            ExecutionJobVertex ejv) {
+    private static void getMaxInputChannelInfo(
+            ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) {
 
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
         List<JobEdge> inputEdges = ejv.getJobVertex().getInputs();
 
         for (int i = 0; i < inputEdges.size(); i++) {
@@ -129,16 +128,15 @@ public class SsgNetworkMemoryCalculationUtils {
                             ejv.getParallelism(),
                             consumedResult.getNumberOfAssignedPartitions(),
                             inputEdge.getDistributionPattern());
-            ret.merge(consumedResult.getId(), maxNum, Integer::sum);
+            maxInputChannelNums.merge(consumedResult.getId(), maxNum, Integer::sum);
         }
-
-        return ret;
     }
 
-    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNums(
-            ExecutionJobVertex ejv, Function<JobVertexID, ExecutionJobVertex> ejvs) {
-
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+    private static void getMaxSubpartitionInfo(
+            ExecutionJobVertex ejv,
+            Map<IntermediateDataSetID, Integer> maxSubpartitionNums,
+            Map<IntermediateDataSetID, ResultPartitionType> partitionTypes,
+            Function<JobVertexID, ExecutionJobVertex> ejvs) {
         List<IntermediateDataSet> producedDataSets = ejv.getJobVertex().getProducedDataSets();
 
         checkState(!ejv.getGraph().isDynamic(), "Only support non-dynamic graph.");
@@ -157,23 +155,14 @@ public class SsgNetworkMemoryCalculationUtils {
                                 consumerJobVertex.getParallelism(),
                                 outputEdge.getDistributionPattern());
             }
-            ret.put(producedDataSet.getId(), maxNum);
+            maxSubpartitionNums.put(producedDataSet.getId(), maxNum);
+            partitionTypes.putIfAbsent(producedDataSet.getId(), producedDataSet.getResultType());
         }
-
-        return ret;
-    }
-
-    private static Map<IntermediateDataSetID, ResultPartitionType> getPartitionTypes(JobVertex jv) {
-        Map<IntermediateDataSetID, ResultPartitionType> ret = new HashMap<>();
-        jv.getProducedDataSets().forEach(ds -> ret.putIfAbsent(ds.getId(), ds.getResultType()));
-        return ret;
     }
 
     @VisibleForTesting
-    static Map<IntermediateDataSetID, Integer> getMaxInputChannelNumsForDynamicGraph(
-            ExecutionJobVertex ejv) {
-
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+    static void getMaxInputChannelInfoForDynamicGraph(
+            ExecutionJobVertex ejv, Map<IntermediateDataSetID, Integer> maxInputChannelNums) {
 
         for (ExecutionVertex vertex : ejv.getTaskVertices()) {
             Map<IntermediateDataSetID, Integer> tmp = new HashMap<>();
@@ -194,27 +183,25 @@ public class SsgNetworkMemoryCalculationUtils {
             }
 
             for (Map.Entry<IntermediateDataSetID, Integer> entry : tmp.entrySet()) {
-                ret.merge(entry.getKey(), entry.getValue(), Integer::max);
+                maxInputChannelNums.merge(entry.getKey(), entry.getValue(), Integer::max);
             }
         }
-
-        return ret;
     }
 
-    private static Map<IntermediateDataSetID, Integer> getMaxSubpartitionNumsForDynamicGraph(
-            ExecutionJobVertex ejv) {
-
-        Map<IntermediateDataSetID, Integer> ret = new HashMap<>();
+    private static void getMaxSubpartitionInfoForDynamicGraph(
+            ExecutionJobVertex ejv,
+            Map<IntermediateDataSetID, Integer> maxSubpartitionNums,
+            Map<IntermediateDataSetID, ResultPartitionType> partitionTypes) {
 
         for (IntermediateResult intermediateResult : ejv.getProducedDataSets()) {
             final int maxNum =
                     Arrays.stream(intermediateResult.getPartitions())
                             .map(IntermediateResultPartition::getNumberOfSubpartitions)
                             .reduce(0, Integer::max);
-            ret.put(intermediateResult.getId(), maxNum);
+            maxSubpartitionNums.put(intermediateResult.getId(), maxNum);
+            partitionTypes.putIfAbsent(
+                    intermediateResult.getId(), intermediateResult.getResultType());
         }
-
-        return ret;
     }
 
     /** Private default constructor to avoid being instantiated. */
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
index b81824b89bd..45d71b8066f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/scheduler/SsgNetworkMemoryCalculationUtilsTest.java
@@ -48,6 +48,7 @@ import org.junit.ClassRule;
 import org.junit.Test;
 
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -231,8 +232,9 @@ public class SsgNetworkMemoryCalculationUtilsTest {
         consumer.setParallelism(decidedConsumerParallelism);
         eg.initializeJobVertex(consumer, 0L);
 
-        Map<IntermediateDataSetID, Integer> maxInputChannelNums =
-                SsgNetworkMemoryCalculationUtils.getMaxInputChannelNumsForDynamicGraph(consumer);
+        Map<IntermediateDataSetID, Integer> maxInputChannelNums = new HashMap<>();
+        SsgNetworkMemoryCalculationUtils.getMaxInputChannelInfoForDynamicGraph(
+                consumer, maxInputChannelNums);
 
         assertThat(maxInputChannelNums.size(), is(1));
         assertThat(maxInputChannelNums.get(result.getId()), is(expectedNumChannels));


[flink] 02/03: [FLINK-30472][network] Modify the default value of the max network memory config option

Posted by xt...@apache.org.
This is an automated email from the ASF dual-hosted git repository.

xtsong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit afdf4a73e43f2e6f5b2dd984aa3471f32658f9d1
Author: Yuxin Tan <ta...@gmail.com>
AuthorDate: Fri Jan 6 17:41:44 2023 +0800

    [FLINK-30472][network] Modify the default value of the max network memory config option
---
 docs/layouts/shortcodes/generated/common_memory_section.html        | 4 ++--
 .../shortcodes/generated/task_manager_memory_configuration.html     | 4 ++--
 .../java/org/apache/flink/configuration/TaskManagerOptions.java     | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/common_memory_section.html b/docs/layouts/shortcodes/generated/common_memory_section.html
index 75c3312256b..9ee881a7cb6 100644
--- a/docs/layouts/shortcodes/generated/common_memory_section.html
+++ b/docs/layouts/shortcodes/generated/common_memory_section.html
@@ -136,9 +136,9 @@
         </tr>
         <tr>
             <td><h5>taskmanager.memory.network.max</h5></td>
-            <td style="word-wrap: break-word;">1 gb</td>
+            <td style="word-wrap: break-word;">9223372036854775807 bytes</td>
             <td>MemorySize</td>
-            <td>Max Network Memory size for TaskExecutors. Network Memory is off-heap memory reserved for ShuffleEnvironment (e.g., network buffers). Network Memory size is derived to make up the configured fraction of the Total Flink Memory. If the derived size is less/greater than the configured min/max size, the min/max size will be used. The exact size of Network Memory can be explicitly specified by setting the min/max to the same value.</td>
+            <td>Max Network Memory size for TaskExecutors. Network Memory is off-heap memory reserved for ShuffleEnvironment (e.g., network buffers). Network Memory size is derived to make up the configured fraction of the Total Flink Memory. If the derived size is less/greater than the configured min/max size, the min/max size will be used. By default, the max limit of Network Memory is Long.MAX_VALUE. The exact size of Network Memory can be explicitly specified by setting the min/max t [...]
         </tr>
         <tr>
             <td><h5>taskmanager.memory.network.min</h5></td>
diff --git a/docs/layouts/shortcodes/generated/task_manager_memory_configuration.html b/docs/layouts/shortcodes/generated/task_manager_memory_configuration.html
index 99b3e9d04ce..94bd90cecce 100644
--- a/docs/layouts/shortcodes/generated/task_manager_memory_configuration.html
+++ b/docs/layouts/shortcodes/generated/task_manager_memory_configuration.html
@@ -88,9 +88,9 @@
         </tr>
         <tr>
             <td><h5>taskmanager.memory.network.max</h5></td>
-            <td style="word-wrap: break-word;">1 gb</td>
+            <td style="word-wrap: break-word;">9223372036854775807 bytes</td>
             <td>MemorySize</td>
-            <td>Max Network Memory size for TaskExecutors. Network Memory is off-heap memory reserved for ShuffleEnvironment (e.g., network buffers). Network Memory size is derived to make up the configured fraction of the Total Flink Memory. If the derived size is less/greater than the configured min/max size, the min/max size will be used. The exact size of Network Memory can be explicitly specified by setting the min/max to the same value.</td>
+            <td>Max Network Memory size for TaskExecutors. Network Memory is off-heap memory reserved for ShuffleEnvironment (e.g., network buffers). Network Memory size is derived to make up the configured fraction of the Total Flink Memory. If the derived size is less/greater than the configured min/max size, the min/max size will be used. By default, the max limit of Network Memory is Long.MAX_VALUE. The exact size of Network Memory can be explicitly specified by setting the min/max t [...]
         </tr>
         <tr>
             <td><h5>taskmanager.memory.network.min</h5></td>
diff --git a/flink-core/src/main/java/org/apache/flink/configuration/TaskManagerOptions.java b/flink-core/src/main/java/org/apache/flink/configuration/TaskManagerOptions.java
index 8e4e67f0402..66db0e9ce5f 100644
--- a/flink-core/src/main/java/org/apache/flink/configuration/TaskManagerOptions.java
+++ b/flink-core/src/main/java/org/apache/flink/configuration/TaskManagerOptions.java
@@ -504,15 +504,15 @@ public class TaskManagerOptions {
     public static final ConfigOption<MemorySize> NETWORK_MEMORY_MAX =
             key("taskmanager.memory.network.max")
                     .memoryType()
-                    .defaultValue(MemorySize.parse("1g"))
+                    .defaultValue(MemorySize.MAX_VALUE)
                     .withDeprecatedKeys(
                             NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_MEMORY_MAX.key())
                     .withDescription(
                             "Max Network Memory size for TaskExecutors. Network Memory is off-heap memory reserved for"
                                     + " ShuffleEnvironment (e.g., network buffers). Network Memory size is derived to make up the configured"
                                     + " fraction of the Total Flink Memory. If the derived size is less/greater than the configured min/max"
-                                    + " size, the min/max size will be used. The exact size of Network Memory can be explicitly specified by"
-                                    + " setting the min/max to the same value.");
+                                    + " size, the min/max size will be used. By default, the max limit of Network Memory is Long.MAX_VALUE."
+                                    + " The exact size of Network Memory can be explicitly specified by setting the min/max to the same value.");
 
     /** Fraction of Total Flink Memory to be used as Network Memory. */
     @Documentation.Section(Documentation.Sections.COMMON_MEMORY)
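
With the default raised to MemorySize.MAX_VALUE, the fraction-derived size is no longer silently capped at 1 gb; an upper bound now only exists if one is configured. As the option description notes, pinning min and max to the same value fixes the exact size. A small sketch using Flink's Configuration API (the 512 mb figure is just an example value):

    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.configuration.MemorySize;
    import org.apache.flink.configuration.TaskManagerOptions;

    class NetworkMemoryConfigExample {
        static Configuration pinNetworkMemory() {
            Configuration conf = new Configuration();
            // Setting min == max overrides the fraction-based derivation and
            // gives the TaskManager an exact Network Memory budget.
            MemorySize size = MemorySize.parse("512mb");
            conf.set(TaskManagerOptions.NETWORK_MEMORY_MIN, size);
            conf.set(TaskManagerOptions.NETWORK_MEMORY_MAX, size);
            return conf;
        }
    }
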