You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@celeborn.apache.org by zh...@apache.org on 2023/01/13 08:38:50 UTC
[incubator-celeborn] branch main updated: [CELEBORN-158][Flink] Add ShuffleServiceFactory to Support MapPartition in … (#1105)
This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 411ab09f [CELEBORN-158][Flink] Add ShuffleServiceFactory to Support MapPartition in … (#1105)
411ab09f is described below
commit 411ab09ffb3d645716706b8b19ab1d55abdde079
Author: zhongqiangczq <96...@users.noreply.github.com>
AuthorDate: Fri Jan 13 16:38:46 2023 +0800
[CELEBORN-158][Flink] Add ShuffleServiceFactory to Support MapPartition in … (#1105)
---
.../plugin/flink/RemoteShuffleEnvironment.java | 198 +++++++++++++++++++++
.../flink/RemoteShuffleResultPartitionFactory.java | 185 +++++++++++++++++++
.../plugin/flink/RemoteShuffleServiceFactory.java | 58 +++++-
.../flink/RemoteShuffleServiceFactorySuitJ.java | 58 ++++++
.../org/apache/celeborn/common/CelebornConf.scala | 9 +
docs/configuration/client.md | 1 +
6 files changed, 506 insertions(+), 3 deletions(-)
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java
new file mode 100644
index 00000000..0599f433
--- /dev/null
+++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleEnvironment.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkNotNull;
+import static org.apache.celeborn.plugin.flink.utils.Utils.checkState;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_INPUT;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.METRIC_GROUP_OUTPUT;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.createShuffleIOOwnerMetricGroup;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.executiongraph.PartitionInfo;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.PartitionProducerStateProvider;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleIOOwnerContext;
+import org.apache.flink.util.FlinkRuntimeException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+
+/**
+ * {@link ShuffleEnvironment} implementation backed by the remote shuffle service; it provides the
+ * shuffle environment on the Flink TaskManager side.
+ *
+ * <p>Only the write path (result partitions) is wired up so far. {@link #createInputGates}, {@link
+ * #updatePartitionInfo} and {@link #releasePartitionsLocally} are not implemented yet and fail
+ * with a {@link FlinkRuntimeException}.
+ *
+ * <p>The closed flag is guarded by an internal lock that is taken by {@link #start()}, {@link
+ * #close()} and {@link #createResultPartitionWriters}.
+ */
+public class RemoteShuffleEnvironment
+    implements ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> {
+
+  private static final Logger LOG = LoggerFactory.getLogger(RemoteShuffleEnvironment.class);
+
+  /** Network buffer pool for shuffle read and shuffle write. */
+  private final NetworkBufferPool networkBufferPool;
+
+  /** A trivial {@link ResultPartitionManager}. */
+  private final ResultPartitionManager resultPartitionManager;
+
+  /** Factory class to create {@link RemoteShuffleResultPartition}. */
+  private final RemoteShuffleResultPartitionFactory resultPartitionFactory;
+
+  // TODO: add the input gate factory once the shuffle read path is supported.
+
+  /** Whether the shuffle environment is closed. Guarded by {@link #lock}. */
+  private boolean isClosed;
+
+  private final Object lock = new Object();
+
+  /** Celeborn configuration handed to every result partition created by this environment. */
+  private final CelebornConf conf;
+
+  /**
+   * @param networkBufferPool Network buffer pool for shuffle read and shuffle write.
+   * @param resultPartitionManager A trivial {@link ResultPartitionManager}.
+   * @param resultPartitionFactory Factory class to create {@link RemoteShuffleResultPartition}.
+   * @param conf Celeborn configuration passed to the result partition factory on each create call.
+   */
+  public RemoteShuffleEnvironment(
+      NetworkBufferPool networkBufferPool,
+      ResultPartitionManager resultPartitionManager,
+      RemoteShuffleResultPartitionFactory resultPartitionFactory,
+      CelebornConf conf) {
+
+    this.networkBufferPool = networkBufferPool;
+    this.resultPartitionManager = resultPartitionManager;
+    this.resultPartitionFactory = resultPartitionFactory;
+    this.conf = conf;
+    this.isClosed = false;
+  }
+
+  /**
+   * Creates one result partition writer per deployment descriptor, in descriptor order.
+   *
+   * @param ownerContext context of the owning task; its owner name is passed to the factory.
+   * @param resultPartitionDeploymentDescriptors descriptors of the partitions to create.
+   * @return a fixed-size list with one writer per descriptor; the list index is also used as the
+   *     partition index.
+   * @throws IllegalStateException presumably raised by {@code checkState} when the environment
+   *     has already been closed -- the helper is defined outside this file.
+   */
+  @Override
+  public List<ResultPartitionWriter> createResultPartitionWriters(
+      ShuffleIOOwnerContext ownerContext,
+      List<ResultPartitionDeploymentDescriptor> resultPartitionDeploymentDescriptors) {
+
+    synchronized (lock) {
+      checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+
+      ResultPartitionWriter[] resultPartitions =
+          new ResultPartitionWriter[resultPartitionDeploymentDescriptors.size()];
+      for (int index = 0; index < resultPartitions.length; index++) {
+        resultPartitions[index] =
+            resultPartitionFactory.create(
+                ownerContext.getOwnerName(),
+                index,
+                resultPartitionDeploymentDescriptors.get(index),
+                conf);
+      }
+      return Arrays.asList(resultPartitions);
+    }
+  }
+
+  /**
+   * Not supported yet: the shuffle read path has not been implemented.
+   *
+   * <p>Fails fast instead of returning {@code null}, so that callers get an explicit error rather
+   * than a {@link NullPointerException} when they use the returned list.
+   *
+   * @throws FlinkRuntimeException always.
+   */
+  @Override
+  public List<IndexedInputGate> createInputGates(
+      ShuffleIOOwnerContext ownerContext,
+      PartitionProducerStateProvider producerStateProvider,
+      List<InputGateDeploymentDescriptor> inputGateDescriptors) {
+    throw new FlinkRuntimeException("Not implemented yet.");
+  }
+
+  /**
+   * Releases the buffer pools, the result partition manager and the network buffer pool.
+   *
+   * <p>Each step is attempted even if a previous one failed; failures are logged and not
+   * rethrown. Calling this method again after the environment has been closed is a no-op.
+   */
+  @Override
+  public void close() {
+    synchronized (lock) {
+      if (isClosed) {
+        // Already shut down; do not destroy the pools a second time.
+        return;
+      }
+      LOG.info("Close RemoteShuffleEnvironment.");
+      try {
+        networkBufferPool.destroyAllBufferPools();
+      } catch (Throwable t) {
+        LOG.error("Close RemoteShuffleEnvironment failure.", t);
+      }
+      try {
+        resultPartitionManager.shutdown();
+      } catch (Throwable t) {
+        LOG.error("Close RemoteShuffleEnvironment failure.", t);
+      }
+      try {
+        networkBufferPool.destroy();
+      } catch (Throwable t) {
+        LOG.error("Close RemoteShuffleEnvironment failure.", t);
+      }
+      isClosed = true;
+    }
+  }
+
+  /**
+   * Starts the environment. Nothing needs to be started here, so this only verifies that the
+   * environment has not been closed.
+   *
+   * @return a trivial placeholder value ({@code 1}).
+   */
+  @Override
+  public int start() throws IOException {
+    synchronized (lock) {
+      checkState(!isClosed, "The RemoteShuffleEnvironment has already been shut down.");
+      LOG.info("Starting the network environment and its components.");
+      // trivial value.
+      return 1;
+    }
+  }
+
+  /**
+   * Not supported yet.
+   *
+   * @throws FlinkRuntimeException always.
+   */
+  @Override
+  public boolean updatePartitionInfo(ExecutionAttemptID consumerID, PartitionInfo partitionInfo) {
+    throw new FlinkRuntimeException("Not implemented yet.");
+  }
+
+  /**
+   * Builds the per-task I/O context, with output and input metric groups created under a shuffle
+   * I/O metric group derived from {@code parentGroup}.
+   */
+  @Override
+  public ShuffleIOOwnerContext createShuffleIOOwnerContext(
+      String ownerName, ExecutionAttemptID executionAttemptID, MetricGroup parentGroup) {
+    MetricGroup nettyGroup = createShuffleIOOwnerMetricGroup(checkNotNull(parentGroup));
+    return new ShuffleIOOwnerContext(
+        checkNotNull(ownerName),
+        checkNotNull(executionAttemptID),
+        parentGroup,
+        nettyGroup.addGroup(METRIC_GROUP_OUTPUT),
+        nettyGroup.addGroup(METRIC_GROUP_INPUT));
+  }
+
+  /**
+   * Not supported yet.
+   *
+   * @throws FlinkRuntimeException always.
+   */
+  @Override
+  public void releasePartitionsLocally(Collection<ResultPartitionID> partitionIds) {
+    throw new FlinkRuntimeException("Not implemented yet.");
+  }
+
+  /** Always returns a new empty collection; this environment tracks no local partitions. */
+  @Override
+  public Collection<ResultPartitionID> getPartitionsOccupyingLocalResources() {
+    return new ArrayList<>();
+  }
+
+  @VisibleForTesting
+  NetworkBufferPool getNetworkBufferPool() {
+    return networkBufferPool;
+  }
+
+  @VisibleForTesting
+  RemoteShuffleResultPartitionFactory getResultPartitionFactory() {
+    return resultPartitionFactory;
+  }
+}
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java
new file mode 100644
index 00000000..0fbb92d2
--- /dev/null
+++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleResultPartitionFactory.java
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
+import org.apache.flink.runtime.io.network.buffer.BufferCompressor;
+import org.apache.flink.runtime.io.network.buffer.BufferPool;
+import org.apache.flink.runtime.io.network.buffer.BufferPoolFactory;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionID;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
+import org.apache.flink.runtime.shuffle.ShuffleDescriptor;
+import org.apache.flink.util.function.SupplierWithException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.common.protocol.CompressionCodec;
+import org.apache.celeborn.plugin.flink.utils.Utils;
+
+/**
+ * Factory class to create {@link RemoteShuffleResultPartition}.
+ *
+ * <p>The memory configured per result partition is converted into a number of network buffers at
+ * construction time. Every created partition receives two buffer pool suppliers that split this
+ * number between the sorting side and the transportation (output gate) side.
+ */
+public class RemoteShuffleResultPartitionFactory {
+
+  private static final Logger LOG =
+      LoggerFactory.getLogger(RemoteShuffleResultPartitionFactory.class);
+
+  public static final int MIN_BUFFERS_PER_PARTITION = 16;
+
+  /** Smallest accepted value for the per-result-partition memory setting. */
+  private static final String MIN_MEMORY_SIZE = "8m";
+
+  /** Not used and just for compatibility with Flink pluggable shuffle service. */
+  private final ResultPartitionManager partitionManager;
+
+  /** Network buffer pool used for shuffle write buffers. */
+  private final BufferPoolFactory bufferPoolFactory;
+
+  /** Network buffer size. */
+  private final int networkBufferSize;
+
+  /**
+   * Configured number of buffers for shuffle write, it contains two parts: sorting buffers and
+   * transportation buffers.
+   */
+  private final int numBuffersPerPartition;
+
+  /**
+   * @param celebornConf configuration providing the memory size per result partition.
+   * @param partitionManager passed through to the created partitions.
+   * @param bufferPoolFactory factory used to create the buffer pools of each partition.
+   * @param networkBufferSize size of one network buffer in bytes; must be positive.
+   * @throws IllegalArgumentException if {@code networkBufferSize} is not positive, if the
+   *     configured memory is below the minimum, or if it yields fewer than {@link
+   *     #MIN_BUFFERS_PER_PARTITION} buffers.
+   */
+  public RemoteShuffleResultPartitionFactory(
+      CelebornConf celebornConf,
+      ResultPartitionManager partitionManager,
+      BufferPoolFactory bufferPoolFactory,
+      int networkBufferSize) {
+    // Guard the division below against a zero or negative buffer size.
+    if (networkBufferSize <= 0) {
+      throw new IllegalArgumentException(
+          String.format("Network buffer size must be positive, but was %d.", networkBufferSize));
+    }
+
+    long configuredMemorySize =
+        org.apache.celeborn.common.util.Utils.byteStringAsBytes(
+            celebornConf.memoryPerResultPartition());
+    long minConfiguredMemorySize =
+        org.apache.celeborn.common.util.Utils.byteStringAsBytes(MIN_MEMORY_SIZE);
+    if (configuredMemorySize < minConfiguredMemorySize) {
+      // Report the minimum in its readable form rather than as a raw byte count without a unit.
+      throw new IllegalArgumentException(
+          String.format(
+              "Insufficient network memory per result partition, please increase %s "
+                  + "to at least %s.",
+              CelebornConf.MEMORY_PER_RESULT_PARTITION().key(), MIN_MEMORY_SIZE));
+    }
+
+    this.numBuffersPerPartition = Utils.checkedDownCast(configuredMemorySize / networkBufferSize);
+    if (numBuffersPerPartition < MIN_BUFFERS_PER_PARTITION) {
+      // Widen to long so the suggested size cannot overflow for large buffer sizes.
+      long minRequiredBytes = (long) networkBufferSize * MIN_BUFFERS_PER_PARTITION;
+      throw new IllegalArgumentException(
+          String.format(
+              "Insufficient network memory per partition, please increase %s to at "
+                  + "least %d bytes.",
+              CelebornConf.MEMORY_PER_RESULT_PARTITION().key(), minRequiredBytes));
+    }
+
+    this.partitionManager = partitionManager;
+    this.bufferPoolFactory = bufferPoolFactory;
+    this.networkBufferSize = networkBufferSize;
+  }
+
+  /**
+   * Creates the result partition described by {@code desc}.
+   *
+   * @param taskNameWithSubtaskAndId name of the owning task, used for the partition and logging.
+   * @param partitionIndex index of the partition within the owning task.
+   * @param desc deployment descriptor; its shuffle descriptor must be a {@link
+   *     RemoteShuffleDescriptor}.
+   * @param celebornConf configuration providing the compression codec and passed to the output
+   *     gate.
+   * @return the newly created result partition.
+   * @throws IllegalStateException if the configured compression codec is not LZ4.
+   */
+  public ResultPartition create(
+      String taskNameWithSubtaskAndId,
+      int partitionIndex,
+      ResultPartitionDeploymentDescriptor desc,
+      CelebornConf celebornConf) {
+    LOG.info(
+        "Create result partition -- number of buffers per result partition={}, "
+            + "number of subpartitions={}.",
+        numBuffersPerPartition,
+        desc.getNumberOfSubpartitions());
+
+    return create(
+        taskNameWithSubtaskAndId,
+        partitionIndex,
+        desc.getShuffleDescriptor().getResultPartitionID(),
+        desc.getPartitionType(),
+        desc.getNumberOfSubpartitions(),
+        desc.getMaxParallelism(),
+        createBufferPoolFactory(),
+        desc.getShuffleDescriptor(),
+        celebornConf,
+        desc.getTotalNumberOfPartitions());
+  }
+
+  /**
+   * Assembles the partition and its output gate. {@code bufferPoolFactories} must hold the result
+   * partition pool supplier at index 0 and the output gate pool supplier at index 1.
+   */
+  private ResultPartition create(
+      String taskNameWithSubtaskAndId,
+      int partitionIndex,
+      ResultPartitionID id,
+      ResultPartitionType type,
+      int numSubpartitions,
+      int maxParallelism,
+      List<SupplierWithException<BufferPool, IOException>> bufferPoolFactories,
+      ShuffleDescriptor shuffleDescriptor,
+      CelebornConf celebornConf,
+      int numMappers) {
+
+    // in flink1.14/1.15, just support LZ4
+    if (celebornConf.shuffleCompressionCodec() != CompressionCodec.LZ4) {
+      throw new IllegalStateException(
+          "Unknown CompressionMethod " + celebornConf.shuffleCompressionCodec());
+    }
+    final BufferCompressor bufferCompressor =
+        new BufferCompressor(networkBufferSize, celebornConf.shuffleCompressionCodec().name());
+    RemoteShuffleDescriptor rsd = (RemoteShuffleDescriptor) shuffleDescriptor;
+    ResultPartition partition =
+        new RemoteShuffleResultPartition(
+            taskNameWithSubtaskAndId,
+            partitionIndex,
+            id,
+            type,
+            numSubpartitions,
+            maxParallelism,
+            networkBufferSize,
+            partitionManager,
+            bufferCompressor,
+            bufferPoolFactories.get(0),
+            new RemoteShuffleOutputGate(
+                rsd,
+                numSubpartitions,
+                networkBufferSize,
+                bufferPoolFactories.get(1),
+                celebornConf,
+                numMappers));
+    // Log the partition that was just built, not this factory.
+    LOG.debug("{}: Initialized {}", taskNameWithSubtaskAndId, partition);
+    return partition;
+  }
+
+  /**
+   * Used to create 2 buffer pools -- sorting buffer pool (7/8), transportation buffer pool (1/8).
+   *
+   * <p>Only suppliers are returned; no pool is created until a supplier is invoked.
+   */
+  private List<SupplierWithException<BufferPool, IOException>> createBufferPoolFactory() {
+    int numForResultPartition = numBuffersPerPartition * 7 / 8;
+    int numForOutputGate = numBuffersPerPartition - numForResultPartition;
+
+    List<SupplierWithException<BufferPool, IOException>> factories = new ArrayList<>();
+    factories.add(
+        () -> bufferPoolFactory.createBufferPool(numForResultPartition, numForResultPartition));
+    factories.add(() -> bufferPoolFactory.createBufferPool(numForOutputGate, numForOutputGate));
+    return factories;
+  }
+
+  @VisibleForTesting
+  int getNetworkBufferSize() {
+    return networkBufferSize;
+  }
+}
diff --git a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
index 6294f3da..7a93511b 100644
--- a/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
+++ b/client-flink/flink-1.14/src/main/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactory.java
@@ -17,9 +17,27 @@
package org.apache.celeborn.plugin.flink;
+import static org.apache.flink.runtime.io.network.metrics.NettyShuffleMetricFactory.registerShuffleMetrics;
+
+import java.time.Duration;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
+import org.apache.flink.metrics.MetricGroup;
import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionManager;
import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
-import org.apache.flink.runtime.shuffle.*;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
+import org.apache.flink.runtime.shuffle.ShuffleMaster;
+import org.apache.flink.runtime.shuffle.ShuffleMasterContext;
+import org.apache.flink.runtime.shuffle.ShuffleServiceFactory;
+import org.apache.flink.runtime.util.ConfigurationParserUtils;
+
+import org.apache.celeborn.common.CelebornConf;
+import org.apache.celeborn.plugin.flink.utils.FlinkUtils;
public class RemoteShuffleServiceFactory
implements ShuffleServiceFactory<
@@ -34,7 +52,41 @@ public class RemoteShuffleServiceFactory
@Override
public ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> createShuffleEnvironment(
ShuffleEnvironmentContext shuffleEnvironmentContext) {
- // TODO
- return null;
+ Configuration configuration = shuffleEnvironmentContext.getConfiguration();
+ int bufferSize = ConfigurationParserUtils.getPageSize(configuration);
+ final int numBuffers =
+ calculateNumberOfNetworkBuffers(
+ shuffleEnvironmentContext.getNetworkMemorySize(), bufferSize);
+
+ ResultPartitionManager resultPartitionManager = new ResultPartitionManager();
+ MetricGroup metricGroup = shuffleEnvironmentContext.getParentMetricGroup();
+
+ Duration requestSegmentsTimeout =
+ Duration.ofMillis(
+ configuration.getLong(
+ NettyShuffleEnvironmentOptions
+ .NETWORK_EXCLUSIVE_BUFFERS_REQUEST_TIMEOUT_MILLISECONDS));
+ NetworkBufferPool networkBufferPool =
+ new NetworkBufferPool(numBuffers, bufferSize, requestSegmentsTimeout);
+
+ registerShuffleMetrics(metricGroup, networkBufferPool);
+ CelebornConf celebornConf = FlinkUtils.toCelebornConf(configuration);
+ RemoteShuffleResultPartitionFactory resultPartitionFactory =
+ new RemoteShuffleResultPartitionFactory(
+ celebornConf, resultPartitionManager, networkBufferPool, bufferSize);
+
+ return new RemoteShuffleEnvironment(
+ networkBufferPool, resultPartitionManager, resultPartitionFactory, celebornConf);
+ }
+
+ private static int calculateNumberOfNetworkBuffers(MemorySize memorySize, int bufferSize) {
+ long numBuffersLong = memorySize.getBytes() / bufferSize;
+ if (numBuffersLong > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException(
+ "The given number of memory bytes ("
+ + memorySize.getBytes()
+ + ") corresponds to more than MAX_INT pages.");
+ }
+ return (int) numBuffersLong;
}
}
diff --git a/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java
new file mode 100644
index 00000000..9a4f232f
--- /dev/null
+++ b/client-flink/flink-1.14/src/test/java/org/apache/celeborn/plugin/flink/RemoteShuffleServiceFactorySuitJ.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.plugin.flink;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.MemorySize;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.consumer.IndexedInputGate;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironment;
+import org.apache.flink.runtime.shuffle.ShuffleEnvironmentContext;
+import org.junit.Assert;
+import org.junit.Test;
+
+/** Tests for {@link RemoteShuffleServiceFactory}. */
+public class RemoteShuffleServiceFactorySuitJ {
+
+  /**
+   * Creates a shuffle environment from a default Flink configuration with 64 MiB of network
+   * memory and checks the buffer size the result partition factory was created with.
+   */
+  @Test
+  public void testCreateShuffleEnvironment() {
+    RemoteShuffleServiceFactory remoteShuffleServiceFactory = new RemoteShuffleServiceFactory();
+    ShuffleEnvironmentContext shuffleEnvironmentContext = mock(ShuffleEnvironmentContext.class);
+    when(shuffleEnvironmentContext.getConfiguration()).thenReturn(new Configuration());
+    when(shuffleEnvironmentContext.getNetworkMemorySize())
+        .thenReturn(new MemorySize(64 * 1024 * 1024));
+
+    // Mock the metric group hierarchy that metric registration walks through.
+    MetricGroup parentMetricGroup = mock(MetricGroup.class);
+    when(shuffleEnvironmentContext.getParentMetricGroup()).thenReturn(parentMetricGroup);
+    MetricGroup childGroup = mock(MetricGroup.class);
+    MetricGroup childChildGroup = mock(MetricGroup.class);
+    when(parentMetricGroup.addGroup(anyString())).thenReturn(childGroup);
+    when(childGroup.addGroup(any())).thenReturn(childChildGroup);
+    when(childChildGroup.gauge(any(), any())).thenReturn(null);
+
+    ShuffleEnvironment<ResultPartitionWriter, IndexedInputGate> shuffleEnvironment =
+        remoteShuffleServiceFactory.createShuffleEnvironment(shuffleEnvironmentContext);
+    RemoteShuffleEnvironment remoteShuffleEnvironment =
+        (RemoteShuffleEnvironment) shuffleEnvironment;
+    try {
+      Assert.assertEquals(
+          32 * 1024, remoteShuffleEnvironment.getResultPartitionFactory().getNetworkBufferSize());
+    } finally {
+      // Release the network buffer pool allocated by the factory so the test does not leak it.
+      remoteShuffleEnvironment.close();
+    }
+  }
+}
diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 4e45b8f3..109cf20c 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -744,6 +744,7 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
def workerDirectMemoryRatioForReadBuffer: Double = get(WORKER_DIRECT_MEMORY_RATIO_FOR_READ_BUFFER)
def workerDirectMemoryRatioForShuffleStorage: Double =
get(WORKER_DIRECT_MEMORY_RATIO_FOR_SHUFFLE_STORAGE)
+ def memoryPerResultPartition: String = get(MEMORY_PER_RESULT_PARTITION)
/**
* @return workingDir, usable space, flusher thread count, disk type
@@ -2809,4 +2810,12 @@ object CelebornConf extends Logging {
.doc("The time before a cache item is removed.")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("15s")
+
+ val MEMORY_PER_RESULT_PARTITION: ConfigEntry[String] =
+ buildConf("celeborn.client.network.memory.perResultPartition")
+ .categories("client")
+ .version("0.3.0")
+ .doc("The size of network buffers required per result partition. The minimum valid value is 8M. Usually, several hundreds of megabytes memory is enough for large scale batch jobs.")
+ .stringConf
+ .createWithDefault("64m")
}
diff --git a/docs/configuration/client.md b/docs/configuration/client.md
index 212b51e5..0ac32a45 100644
--- a/docs/configuration/client.md
+++ b/docs/configuration/client.md
@@ -23,6 +23,7 @@ license: |
| celeborn.client.blacklistSlave.enabled | true | When true, Celeborn will add partition's peer worker into blacklist when push data to slave failed. | 0.3.0 |
| celeborn.client.closeIdleConnections | true | Whether client will close idle connections. | 0.3.0 |
| celeborn.client.maxRetries | 15 | Max retry times for client to connect master endpoint | 0.2.0 |
+| celeborn.client.network.memory.perResultPartition | 64m | The size of network buffers required per result partition. The minimum valid value is 8M. Usually, several hundreds of megabytes memory is enough for large scale batch jobs. | 0.3.0 |
| celeborn.fetch.maxReqsInFlight | 3 | Amount of in-flight chunk fetch request. | 0.2.0 |
| celeborn.fetch.maxRetries | 3 | Max retries of fetch chunk | 0.2.0 |
| celeborn.fetch.timeout | 120s | Timeout for a task to fetch chunk. | 0.2.0 |