Posted to commits@celeborn.apache.org by zh...@apache.org on 2023/01/13 08:38:50 UTC

[incubator-celeborn] branch main updated: [CELEBORN-158][Flink] Add ShuffleServiceFactory to Support MapPartition in … (#1105)

This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 411ab09f  [CELEBORN-158][Flink] Add ShuffleServiceFactory to Support MapPartition in … (#1105)
411ab09f is described below

commit 411ab09ffb3d645716706b8b19ab1d55abdde079
Author: zhongqiangczq <96...@users.noreply.github.com>
AuthorDate: Fri Jan 13 16:38:46 2023 +0800

     [CELEBORN-158][Flink] Add ShuffleServiceFactory to Support MapPartition in … (#1105)
---
 .../plugin/flink/RemoteShuffleEnvironment.java     | 198 +++++++++++++++++++++
 .../flink/RemoteShuffleResultPartitionFactory.java | 185 +++++++++++++++++++
 .../plugin/flink/RemoteShuffleServiceFactory.java  |  58 +++++-
 .../flink/RemoteShuffleServiceFactorySuitJ.java    |  58 ++++++
 .../org/apache/celeborn/common/CelebornConf.scala  |   9 +
 docs/configuration/client.md                       |   1 +
 6 files changed, 506 insertions(+), 3 deletions(-)
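
This change wires Celeborn into Flink's pluggable shuffle service for the write (MapPartition) path: RemoteShuffleServiceFactory now builds a real RemoteShuffleEnvironment backed by a shared NetworkBufferPool, and the new RemoteShuffleResultPartitionFactory turns the per-partition memory budget into result partition writers. As a minimal sketch (not part of this commit), a deployment would select the plugin through Flink's standard shuffle-service-factory.class option:

    import org.apache.flink.configuration.Configuration;

    public class CelebornShuffleServiceExample {
      public static void main(String[] args) {
        Configuration flinkConf = new Configuration();
        // Tell Flink to load the Celeborn shuffle plugin instead of the default Netty shuffle.
        flinkConf.setString(
            "shuffle-service-factory.class",
            "org.apache.celeborn.plugin.flink.RemoteShuffleServiceFactory");
        // flinkConf is then passed to the cluster configuration as usual.
      }
    }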

diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java
new file mode 100644
index 00000000..0599f433
--- /dev/null
+++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_INPUT;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_OUTPUT;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.createShuffleIOOwnerMetricGroup;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.PartitionInfo;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleIOOwnerContext;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+
+/**
+ * The implementation of {@link ShuffleEnvironment} based on the remote shuffle service, providing
+ * the shuffle environment on the Flink TaskManager side.
+ */
+public class RemoteShuffleEnvironment
+    implements ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleEnvironment.class);
+
+  /** Network buffer pool for shuffle read and shuffle write. */
+  private final NetworkBufferPool networkBufferPool;
+
+  /** A trivial {@link ResultPartitionManager}. */
+  private final ResultPartitionManager resultPartitionManager;
+
+  /** Factory class to create {@link RemoteShuffleResultPartition}. */
+  private final RemoteShuffleResultPartitionFactory resultPartitionFactory;
+
+  // /** Factory class to create {@link RemoteShuffleInputGate}. */
+  // private final RemoteShuffleInputGateFactory inputGateFactory;
+
+  /** Whether the shuffle environment is closed. */
+  private boolean isClosed;
+
+  private final Object lock = new Object();
+
+  private final CelebornConf conf;
+
+  /**
+   * @param networkBufferPool Network buffer pool for shuffle read and shuffle write.
+   * @param resultPartitionManager A trivial {@link ResultPartitionManager}.
+   * @param resultPartitionFactory Factory class to create {@link RemoteShuffleResultPartition}.
+   * @param conf Celeborn configuration used when creating result partitions.
+   */
+  public RemoteShuffleEnvironment(
+      NetworkBufferPool networkBufferPool,
+      ResultPartitionManager resultPartitionManager,
+      RemoteShuffleResultPartitionFactory resultPartitionFactory,
+      //        RemoteShuffleInputGateFactory inputGateFactory,
+      CelebornConf conf) {
+
+    this.networkBufferPool = networkBufferPool;
+    this.resultPartitionManager = resultPartitionManager;
+    this.resultPartitionFactory = resultPartitionFactory;
+    //    this.inputGateFactory = inputGateFactory;
+    this.conf = conf;
+    this.isClosed = false;
+  }
+
+  @Override
+  public List<ResultPartitionWriter> createResultPartitionWriters(
+      ShuffleIOOwnerContext ownerContext,
+      List<ResultPartitionDeploymentDescriptor> resultPartitionDeploymentDescriptors) {
+
+    synchronized (lock) {
+      checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+
+      ResultPartitionWriter[] resultPartitions =
+          new ResultPartitionWriter[resultPartitionDeploymentDescriptors.size()];
+      for (int index = 0; index < resultPartitions.length; index++) {
+        resultPartitions[index] =
+            resultPartitionFactory.create(
+                ownerContext.getOwnerName(), index,
+                resultPartitionDeploymentDescriptors.get(index), conf);
+      }
+      return Arrays.asList(resultPartitions);
+    }
+  }
+
+  @Override
+  public List<IndexedInputGate> createInputGates(
+      ShuffleIOOwnerContext ownerContext,
+      PartitionProducerStateProvider producerStateProvider,
+      List<InputGateDeploymentDescriptor> inputGateDescriptors) {
+    return null;
+  }
+
+  @Override
+  public void close() {
+    LOG.info("Close RemoteShuffleEnvironment.");
+    synchronized (lock) {
+      try {
+        networkBufferPool.destroyAllBufferPools();
+      } catch (Throwable t) {
+        LOG.error("Close RemoteShuffleEnvironment failure.", t);
+      }
+      try {
+        resultPartitionManager.shutdown();
+      } catch (Throwable t) {
+        LOG.error("Close RemoteShuffleEnvironment failure.", t);
+      }
+      try {
+        networkBufferPool.destroy();
+      } catch (Throwable t) {
+        LOG.error("Close RemoteShuffleEnvironment failure.", t);
+      }
+      isClosed = true;
+    }
+  }
+
+  @Override
+  public int start() throws IOException {
+    synchronized (lock) {
+      checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+      LOG.info("Starting the network environment and its components.");
+      // trivial value.
+      return 1;
+    }
+  }
+
+  @Override
+  public boolean updatePartitionInfo(ExecutionAttemptID consumerID, PartitionInfo partitionInfo) {
+    throw new FlinkRuntimeException("Not implemented yet.");
+  }
+
+  @Override
+  public ShuffleIOOwnerContext createShuffleIOOwnerContext(
+      String ownerName, ExecutionAttemptID executionAttemptID, MetricGroup parentGroup) {
+    MetricGroup nettyGroup = createShuffleIOOwnerMetricGroup(checkNotNull(parentGroup));
+    return new ShuffleIOOwnerContext(
+        checkNotNull(ownerName),
+        checkNotNull(executionAttemptID),
+        parentGroup,
+        nettyGroup.addGroup(METRIC_GROUP_OUTPUT),
+        nettyGroup.addGroup(METRIC_GROUP_INPUT));
+  }
+
+  @Override
+  public void releasePartitionsLocally(Collection<ResultPartitionID> partitionIds) {
+    throw new FlinkRuntimeException("Not implemented yet.");
+  }
+
+  @Override
+  public Collection<ResultPartitionID> getPartitionsOccupyingLocalResources() {
+    return new ArrayList<>();
+  }
+
+  @VisibleForTesting
+  NetworkBufferPool getNetworkBufferPool() {
+    return networkBufferPool;
+  }
+
+  @VisibleForTesting
+  RemoteShuffleResultPartitionFactory getResultPartitionFactory() {
+    return resultPartitionFactory;
+  }
+}
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java
new file mode 100644
index 00000000..0fbb92d2
--- /dev/null
+++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/** Factory class to create {@link RemoteShuffleResultPartition}. */
+public class RemoteShuffleResultPartitionFactory {
+
+  private static final Logger LOG =
+      LoggerFactory.getLogger(RemoteShuffleResultPartitionFactory.class);
+
+  public static final int MIN_BUFFERS_PER_PARTITION = 16;
+
+  /** Not used and just for compatibility with Flink pluggable shuffle service. */
+  private final ResultPartitionManager partitionManager;
+
+  /** Network buffer pool used for shuffle write buffers. */
+  private final BufferPoolFactory bufferPoolFactory;
+
+  /** Network buffer size. */
+  private final int networkBufferSize;
+
+  /**
+   * Configured number of buffers for shuffle write; it consists of two parts: sorting buffers
+   * and transportation buffers.
+   */
+  private final int numBuffersPerPartition;
+
+  private final String minMemorySize = "8m";
+
+  public RemoteShuffleResultPartitionFactory(
+      CelebornConf celebornConf,
+      ResultPartitionManager partitionManager,
+      BufferPoolFactory bufferPoolFactory,
+      int networkBufferSize) {
+    long configuredMemorySize =
+        org.apache.celeborn.common.util.Utils.byteStringAsBytes(
+            celebornConf.memoryPerResultPartition());
+    long minConfiguredMemorySize =
+        org.apache.celeborn.common.util.Utils.byteStringAsBytes(minMemorySize);
+    if (configuredMemorySize < minConfiguredMemorySize) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Insufficient network memory per result partition, please increase %s "
+                  + "to at least %s.",
+              CelebornConf.MEMORY_PER_RESULT_PARTITION().key(), minConfiguredMemorySize));
+    }
+
+    this.numBuffersPerPartition = Utils.checkedDownCast(configuredMemorySize / networkBufferSize);
+    if (numBuffersPerPartition < MIN_BUFFERS_PER_PARTITION) {
+      throw new IllegalArgumentException(
+          String.format(
+              "Insufficient network memory per partition, please increase %s to at "
+                  + "least %d bytes.",
+              CelebornConf.MEMORY_PER_RESULT_PARTITION().key(),
+              networkBufferSize * MIN_BUFFERS_PER_PARTITION));
+    }
+
+    this.partitionManager = partitionManager;
+    this.bufferPoolFactory = bufferPoolFactory;
+    this.networkBufferSize = networkBufferSize;
+  }
+
+  public ResultPartition create(
+      String taskNameWithSubtaskAndId,
+      int partitionIndex,
+      ResultPartitionDeploymentDescriptor desc,
+      CelebornConf celebornConf) {
+    LOG.info(
+        "Create result partition -- number of buffers per result partition={}, "
+            + "number of subpartitions={}.",
+        numBuffersPerPartition,
+        desc.getNumberOfSubpartitions());
+
+    return create(
+        taskNameWithSubtaskAndId,
+        partitionIndex,
+        desc.getShuffleDescriptor().getResultPartitionID(),
+        desc.getPartitionType(),
+        desc.getNumberOfSubpartitions(),
+        desc.getMaxParallelism(),
+        createBufferPoolFactory(),
+        desc.getShuffleDescriptor(),
+        celebornConf,
+        desc.getTotalNumberOfPartitions());
+  }
+
+  private ResultPartition create(
+      String taskNameWithSubtaskAndId,
+      int partitionIndex,
+      ResultPartitionID id,
+      ResultPartitionType type,
+      int numSubpartitions,
+      int maxParallelism,
+      List<SupplierWithException<BufferPool, IOException>> bufferPoolFactories,
+      ShuffleDescriptor shuffleDescriptor,
+      CelebornConf celebornConf,
+      int numMappers) {
+
+    // The Flink 1.14/1.15 plugin currently supports only LZ4 compression.
+    if (celebornConf.shuffleCompressionCodec() != CompressionCodec.LZ4) {
+      throw new IllegalStateException(
+          "Unsupported compression codec " + celebornConf.shuffleCompressionCodec());
+    }
+    final BufferCompressor bufferCompressor =
+        new BufferCompressor(networkBufferSize, celebornConf.shuffleCompressionCodec().name());
+    RemoteShuffleDescriptor rsd = (RemoteShuffleDescriptor) shuffleDescriptor;
+    ResultPartition partition =
+        new RemoteShuffleResultPartition(
+            taskNameWithSubtaskAndId,
+            partitionIndex,
+            id,
+            type,
+            numSubpartitions,
+            maxParallelism,
+            networkBufferSize,
+            partitionManager,
+            bufferCompressor,
+            bufferPoolFactories.get(0),
+            new RemoteShuffleOutputGate(
+                rsd,
+                numSubpartitions,
+                networkBufferSize,
+                bufferPoolFactories.get(1),
+                celebornConf,
+                numMappers));
+    LOG.debug("{}: Initialized {}", taskNameWithSubtaskAndId, this);
+    return partition;
+  }
+
+  /**
+   * Used to create 2 buffer pools -- sorting buffer pool (7/8), transportation buffer pool (1/8).
+   */
+  private List<SupplierWithException<BufferPool, IOException>> createBufferPoolFactory() {
+    int numForResultPartition = numBuffersPerPartition * 7 / 8;
+    int numForOutputGate = numBuffersPerPartition - numForResultPartition;
+
+    List<SupplierWithException<BufferPool, IOException>> factories = new ArrayList<>();
+    factories.add(
+        () -> bufferPoolFactory.createBufferPool(numForResultPartition, numForResultPartition));
+    factories.add(() -> bufferPoolFactory.createBufferPool(numForOutputGate, numForOutputGate));
+    return factories;
+  }
+
+  @VisibleForTesting
+  int getNetworkBufferSize() {
+    return networkBufferSize;
+  }
+}
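
The factory above converts celeborn.client.network.memory.perResultPartition into a fixed number of network buffers and then splits that budget 7/8 for sorting and 1/8 for transport. A standalone sketch (not part of the patch) of that arithmetic with the default 64m budget and Flink's default 32 KiB buffer:

    public class BufferBudgetExample {
      public static void main(String[] args) {
        long memoryPerResultPartition = 64L * 1024 * 1024; // default "64m"
        int networkBufferSize = 32 * 1024;                 // Flink's default page size
        int numBuffersPerPartition =
            (int) (memoryPerResultPartition / networkBufferSize);              // 2048 buffers
        int numForResultPartition = numBuffersPerPartition * 7 / 8;            // 1792 sorting buffers
        int numForOutputGate = numBuffersPerPartition - numForResultPartition; // 256 transport buffers
        System.out.printf("%d = %d + %d%n",
            numBuffersPerPartition, numForResultPartition, numForOutputGate);
      }
    }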
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
index 6294f3da..7a93511b 100644
--- a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
+++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
@@ -17,9 +17,27 @@
 
 package org.apache.celeborn.plugin.flink;
 
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.registerShuffleMetrics;
+
+import java.time.Duration;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
+import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
 import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
-import org.apache.flink.runtime.shuffle.*;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
+import org.apache.flink.runtime.shuffle.ShuffleMaster;
+import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
+import org.apache.flink.runtime.shuffle.ShuffleServiceFactory;
+import org.apache.flink.runtime.util.ConfigurationParserUtils;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.utils.FlinkUtils;
 
 public class RemoteShuffleServiceFactory
     implements ShuffleServiceFactory<
@@ -34,7 +52,41 @@ public class RemoteShuffleServiceFactory
   @Override
   public ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> createShuffleEnvironment(
       ShuffleEnvironmentContext shuffleEnvironmentContext) {
-    // TODO
-    return null;
+    Configuration configuration = shuffleEnvironmentContext.getConfiguration();
+    int bufferSize = ConfigurationParserUtils.getPageSize(configuration);
+    final int numBuffers =
+        calculateNumberOfNetworkBuffers(
+            shuffleEnvironmentContext.getNetworkMemorySize(), bufferSize);
+
+    ResultPartitionManager resultPartitionManager = new ResultPartitionManager();
+    MetricGroup metricGroup = shuffleEnvironmentContext.getParentMetricGroup();
+
+    Duration requestSegmentsTimeout =
+        Duration.ofMillis(
+            configuration.getLong(
+                NettyShuffleEnvironmentOptions
+                    .NETWORK_EXCLUSIVE_BUFFERS_REQUEST_TIMEOUT_MILLISECONDS));
+    NetworkBufferPool networkBufferPool =
+        new NetworkBufferPool(numBuffers, bufferSize, requestSegmentsTimeout);
+
+    registerShuffleMetrics(metricGroup, networkBufferPool);
+    CelebornConf celebornConf = FlinkUtils.toCelebornConf(configuration);
+    RemoteShuffleResultPartitionFactory resultPartitionFactory =
+        new RemoteShuffleResultPartitionFactory(
+            celebornConf, resultPartitionManager, networkBufferPool, bufferSize);
+
+    return new RemoteShuffleEnvironment(
+        networkBufferPool, resultPartitionManager, resultPartitionFactory, celebornConf);
+  }
+
+  private static int calculateNumberOfNetworkBuffers(MemorySize memorySize, int bufferSize) {
+    long numBuffersLong = memorySize.getBytes() / bufferSize;
+    if (numBuffersLong > Integer.MAX_VALUE) {
+      throw new IllegalArgumentException(
+          "The given number of memory bytes ("
+              + memorySize.getBytes()
+              + ") corresponds to more than MAX_INT pages.");
+    }
+    return (int) numBuffersLong;
   }
 }
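
createShuffleEnvironment sizes one NetworkBufferPool for the whole TaskManager from the network memory Flink hands over; the per-partition pools created by RemoteShuffleResultPartitionFactory are then carved out of this shared pool as local BufferPools. A sketch of the global sizing with assumed values (not part of the commit):

    import java.time.Duration;
    import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;

    public class GlobalPoolSizingExample {
      public static void main(String[] args) {
        long networkMemoryBytes = 64L * 1024 * 1024; // MemorySize handed over by Flink
        int bufferSize = 32 * 1024;                  // page size from the configuration
        int numBuffers = (int) (networkMemoryBytes / bufferSize); // 2048 segments for the whole TM
        NetworkBufferPool pool =
            new NetworkBufferPool(numBuffers, bufferSize, Duration.ofSeconds(30));
        System.out.println(pool.getTotalNumberOfMemorySegments()); // 2048
        pool.destroy();
      }
    }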
diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java
new file mode 100644
index 00000000..9a4f232f
--- /dev/null
+++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
+import org.junit.Assert;
+import org.junit.Test;
+
+public class RemoteShuffleServiceFactorySuitJ {
+  @Test
+  public void testCreateShuffleEnvironment() {
+    RemoteShuffleServiceFactory remoteShuffleServiceFactory = new RemoteShuffleServiceFactory();
+    ShuffleEnvironmentContext shuffleEnvironmentContext = mock(ShuffleEnvironmentContext.class);
+    when(shuffleEnvironmentContext.getConfiguration()).thenReturn(new Configuration());
+    when(shuffleEnvironmentContext.getNetworkMemorySize())
+        .thenReturn(new MemorySize(64 * 1024 * 1024));
+    MetricGroup parentMetric = mock(MetricGroup.class);
+    when(shuffleEnvironmentContext.getParentMetricGroup()).thenReturn(parentMetric);
+    MetricGroup childGroup = mock(MetricGroup.class);
+    MetricGroup childChildGroup = mock(MetricGroup.class);
+    when(parentMetric.addGroup(anyString())).thenReturn(childGroup);
+    when(childGroup.addGroup(any())).thenReturn(childChildGroup);
+    when(childChildGroup.gauge(any(), any())).thenReturn(null);
+    ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> shuffleEnvironment =
+        remoteShuffleServiceFactory.createShuffleEnvironment(shuffleEnvironmentContext);
+    Assert.assertEquals(
+        32 * 1024,
+        ((RemoteShuffleEnvironment) shuffleEnvironment)
+            .getResultPartitionFactory()
+            .getNetworkBufferSize());
+  }
+}
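
The assertion above expects 32 * 1024 because, with an empty Flink Configuration, ConfigurationParserUtils.getPageSize falls back to Flink's default memory segment size of 32 KiB, which the factory passes through as the result partition factory's networkBufferSize. A small sketch of that fallback (for clarity only, not part of the commit):

    import org.apache.flink.configuration.Configuration;
    import org.apache.flink.runtime.util.ConfigurationParserUtils;

    public class PageSizeExample {
      public static void main(String[] args) {
        // With no explicit taskmanager.memory.segment-size, Flink defaults to 32 KiB pages.
        int pageSize = ConfigurationParserUtils.getPageSize(new Configuration());
        System.out.println(pageSize); // 32768
      }
    }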
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 4e45b8f3..109cf20c 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -744,6 +744,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
   def workerDirectMemoryRatioForReadBuffer: Double = get(WORKER_DIRECT_MEMORY_RATIO_FOR_READ_BUFFER)
   def workerDirectMemoryRatioForShuffleStorage: Double =
     get(WORKER_DIRECT_MEMORY_RATIO_FOR_SHUFFLE_STORAGE)
+  def memoryPerResultPartition: String = get(MEMORY_PER_RESULT_PARTITION)
 
   /**
    * @return workingDir, usable space, flusher thread count, disk type
@@ -2809,4 +2810,12 @@ object CelebornConf extends Logging {
       .doc("The time before a cache item is removed.")
       .timeConf(TimeUnit.MILLISECONDS)
       .createWithDefaultString("15s")
+
+  val MEMORY_PER_RESULT_PARTITION: ConfigEntry[String] =
+    buildConf("celeborn.client.network.memory.perResultPartition")
+      .categories("client")
+      .version("0.3.0")
+      .doc("The size of network buffers required per result partition. The minimum valid value is 8M. Usually, several hundreds of megabytes memory is enough for large scale batch jobs.")
+      .stringConf
+      .createWithDefault("64m")
 }
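
The new entry is read on the client through CelebornConf.memoryPerResultPartition. Since the factory builds its CelebornConf from the Flink Configuration via FlinkUtils.toCelebornConf, a job could raise the budget by setting the same key on the Flink side; the sketch below assumes toCelebornConf passes celeborn.* entries through verbatim, which is not shown in this diff:

    import org.apache.flink.configuration.Configuration;

    public class ResultPartitionMemoryExample {
      public static void main(String[] args) {
        Configuration flinkConf = new Configuration();
        // Give each result partition a 128 MiB network buffer budget instead of the default 64m.
        flinkConf.setString("celeborn.client.network.memory.perResultPartition", "128m");
      }
    }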
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 212b51e5..0ac32a45 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -23,6 +23,7 @@ license: |
 | celeborn.client.blacklistSlave.enabled | true | When true, Celeborn will add partition's peer worker into blacklist when push data to slave failed. | 0.3.0 | 
 | celeborn.client.closeIdleConnections | true | Whether client will close idle connections. | 0.3.0 | 
 | celeborn.client.maxRetries | 15 | Max retry times for client to connect master endpoint | 0.2.0 | 
+| celeborn.client.network.memory.perResultPartition | 64m | The size of network buffers required per result partition. The minimum valid value is 8m. Usually, several hundred megabytes of memory is enough for large-scale batch jobs. | 0.3.0 | 
 | celeborn.fetch.maxReqsInFlight | 3 | Amount of in-flight chunk fetch request. | 0.2.0 | 
 | celeborn.fetch.maxRetries | 3 | Max retries of fetch chunk | 0.2.0 | 
 | celeborn.fetch.timeout | 120s | Timeout for a task to fetch chunk. | 0.2.0 |