You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by mr...@apache.org on 2021/01/08 18:22:37 UTC
[spark] branch master updated: [SPARK-32917][SHUFFLE][CORE] Adds
support for executors to push shuffle blocks after successful map task
completion
This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d00f069 [SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion
d00f069 is described below
commit d00f0695b7513046e42e47f35b280d7aa494de5b
Author: Chandni Singh <si...@gmail.com>
AuthorDate: Fri Jan 8 12:21:56 2021 -0600
[SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion
### What changes were proposed in this pull request?
This is the shuffle writer side change where executors can push data to remote shuffle services. This is needed for push-based shuffle - SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
Summary of changes:
- This adds support for executors to push shuffle blocks after map tasks complete writing shuffle data.
- This also introduces a timeout specifically for creating connection to remote shuffle services.
### Why are the changes needed?
- These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
- The main reason to create a separate connection creation timeout is because the existing `connectionTimeoutMs` is overloaded and is used for connection creation timeouts as well as connection idle timeout. The connection creation timeout should be much lower than the idle timeouts. The default for `connectionTimeoutMs` is 120s. This is quite high for just establishing the connections. If a shuffle server node is bad then the connection creation will fail within few seconds. However [...]
### Does this PR introduce _any_ user-facing change?
Yes. This PR introduces client-side configs for push-based shuffle. If push-based shuffle is turned-off then the users will not see any change.
### How was this patch tested?
Added unit tests.
The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
We have already verified the functionality and the improved performance as documented in the SPIP doc.
Lead-authored-by: Min Shen <mshen@linkedin.com>
Co-authored-by: Chandni Singh <chsingh@linkedin.com>
Co-authored-by: Ye Zhou <yezhou@linkedin.com>
Closes #30312 from otterc/SPARK-32917.
Lead-authored-by: Chandni Singh <si...@gmail.com>
Co-authored-by: Chandni Singh <ch...@linkedin.com>
Co-authored-by: Min Shen <ms...@linked.in.com>
Co-authored-by: Ye Zhou <ye...@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../network/client/TransportClientFactory.java | 7 +-
.../apache/spark/network/util/TransportConf.java | 13 +-
.../shuffle/sort/BypassMergeSortShuffleWriter.java | 5 +-
.../spark/shuffle/sort/UnsafeShuffleWriter.java | 7 +-
.../scala/org/apache/spark/executor/Executor.scala | 3 +-
.../org/apache/spark/internal/config/package.scala | 29 ++
.../apache/spark/shuffle/ShuffleBlockPusher.scala | 450 +++++++++++++++++++++
.../spark/shuffle/ShuffleWriteProcessor.scala | 19 +-
.../org/apache/spark/shuffle/ShuffleWriter.scala | 3 +
.../spark/shuffle/sort/SortShuffleWriter.scala | 6 +-
.../scala/org/apache/spark/storage/BlockId.scala | 11 +-
.../spark/shuffle/ShuffleBlockPusherSuite.scala | 355 ++++++++++++++++
12 files changed, 896 insertions(+), 12 deletions(-)
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 24c436a..43408d4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -254,7 +254,7 @@ public class TransportClientFactory implements Closeable {
// Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
- .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs())
.option(ChannelOption.ALLOCATOR, pooledAllocator);
if (conf.receiveBuf() > 0) {
@@ -280,9 +280,10 @@ public class TransportClientFactory implements Closeable {
// Connect to the remote server
long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
- if (!cf.await(conf.connectionTimeoutMs())) {
+ if (!cf.await(conf.connectionCreationTimeoutMs())) {
throw new IOException(
- String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ String.format("Connecting to %s timed out (%s ms)",
+ address, conf.connectionCreationTimeoutMs()));
} else if (cf.cause() != null) {
throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index d305dfa..f051042 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.util;
import java.util.Locale;
import java.util.Properties;
+import java.util.concurrent.TimeUnit;
import com.google.common.primitives.Ints;
import io.netty.util.NettyRuntime;
@@ -31,6 +32,7 @@ public class TransportConf {
private final String SPARK_NETWORK_IO_MODE_KEY;
private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY;
private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY;
+ private final String SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY;
private final String SPARK_NETWORK_IO_BACKLOG_KEY;
private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY;
private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY;
@@ -54,6 +56,7 @@ public class TransportConf {
SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode");
SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs");
SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout");
+ SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY = getConfKey("io.connectionCreationTimeout");
SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog");
SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY = getConfKey("io.numConnectionsPerPeer");
SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads");
@@ -94,7 +97,7 @@ public class TransportConf {
return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
}
- /** Connect timeout in milliseconds. Default 120 secs. */
+ /** Connection idle timeout in milliseconds. Default 120 secs. */
public int connectionTimeoutMs() {
long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
conf.get("spark.network.timeout", "120s"));
@@ -103,6 +106,14 @@ public class TransportConf {
return (int) defaultTimeoutMs;
}
+ /** Connection creation timeout in milliseconds. Defaults to the connection idle timeout. */
+ public int connectionCreationTimeoutMs() {
+ long connectionTimeoutS = TimeUnit.MILLISECONDS.toSeconds(connectionTimeoutMs());
+ long defaultTimeoutMs = JavaUtils.timeStringAsSec(
+ conf.get(SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY, connectionTimeoutS + "s")) * 1000;
+ return (int) defaultTimeoutMs;
+ }
+
/** Number of concurrent connections between two nodes for fetching data. */
public int numConnectionsPerPeer() {
return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 256789b..3dbee1b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -31,7 +31,6 @@ import scala.Product2;
import scala.Tuple2;
import scala.collection.Iterator;
-import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Closeables;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -178,8 +177,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
}
}
- @VisibleForTesting
- long[] getPartitionLengths() {
+ @Override
+ public long[] getPartitionLengths() {
return partitionLengths;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 79e38a8..e8f94ba 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -88,6 +88,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@Nullable private MapStatus mapStatus;
@Nullable private ShuffleExternalSorter sorter;
+ @Nullable private long[] partitionLengths;
private long peakMemoryUsedBytes = 0;
/** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@@ -219,7 +220,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
serOutputStream = null;
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
- final long[] partitionLengths;
try {
partitionLengths = mergeSpills(spills);
} finally {
@@ -543,4 +543,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
channel.close();
}
}
+
+ @Override
+ public long[] getPartitionLengths() {
+ return partitionLengths;
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c58009c..3865c9c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -47,7 +47,7 @@ import org.apache.spark.metrics.source.JVMCPUSource
import org.apache.spark.resource.ResourceInformation
import org.apache.spark.rpc.RpcTimeout
import org.apache.spark.scheduler._
-import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util._
import org.apache.spark.util.io.ChunkedByteBuffer
@@ -325,6 +325,7 @@ private[spark] class Executor(
case NonFatal(e) =>
logWarning("Unable to stop heartbeater", e)
}
+ ShuffleBlockPusher.stop()
threadPool.shutdown()
// Notify plugins that executor is shutting down so they can terminate cleanly
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index adaf92d..84c6647 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2030,4 +2030,33 @@ package object config {
.version("3.1.0")
.doubleConf
.createWithDefault(5)
+
+ private[spark] val SHUFFLE_NUM_PUSH_THREADS =
+ ConfigBuilder("spark.shuffle.push.numPushThreads")
+ .doc("Specify the number of threads in the block pusher pool. These threads assist " +
+ "in creating connections and pushing blocks to remote shuffle services. By default, the " +
+ "threadpool size is equal to the number of spark executor cores.")
+ .version("3.2.0")
+ .intConf
+ .createOptional
+
+ private[spark] val SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH =
+ ConfigBuilder("spark.shuffle.push.maxBlockSizeToPush")
+ .doc("The max size of an individual block to push to the remote shuffle services. Blocks " +
+ "larger than this threshold are not pushed to be merged remotely. These shuffle blocks " +
+ "will be fetched by the executors in the original manner.")
+ .version("3.2.0")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("1m")
+
+ private[spark] val SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH =
+ ConfigBuilder("spark.shuffle.push.maxBlockBatchSize")
+ .doc("The max size of a batch of shuffle blocks to be grouped into a single push request.")
+ .version("3.2.0")
+ .bytesConf(ByteUnit.BYTE)
+ // Default is 3m because it is greater than 2m which is the default value for
+ // TransportConf#memoryMapBytes. If this defaults to 2m as well it is very likely that each
+ // batch of blocks will be loaded in memory with memory mapping, which has higher overhead
+ // with small MB sized chunk of data.
+ .createWithDefaultString("3m")
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
new file mode 100644
index 0000000..88d084c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+
+import com.google.common.base.Throwables
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleBlockPusher._
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled.
+ * When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle
+ * file and initiates the block push process.
+ *
+ * @param conf spark configuration
+ */
+@Since("3.2.0")
+private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
+ private[this] val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH)
+ private[this] val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH)
+ private[this] val maxBytesInFlight =
+ conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
+ private[this] val maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)
+ private[this] val maxBlocksInFlightPerAddress = conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
+ private[this] var bytesInFlight = 0L
+ private[this] var reqsInFlight = 0
+ private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+ private[this] val deferredPushRequests = new HashMap[BlockManagerId, Queue[PushRequest]]()
+ private[this] val pushRequests = new Queue[PushRequest]
+ private[this] val errorHandler = createErrorHandler()
+ // VisibleForTesting
+ private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+
+ // VisibleForTesting
+ private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
+ new BlockPushErrorHandler() {
+ // For a connection exception against a particular host, we will stop pushing any
+ // blocks to just that host and continue pushing blocks to other hosts. So, here push of
+ // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
+ override def shouldRetryError(t: Throwable): Boolean = {
+ // If the block is too late, there is no need to retry it
+ !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
+ }
+ }
+ }
+
+ /**
+ * Initiates the block push.
+ *
+ * @param dataFile mapper generated shuffle data file
+ * @param partitionLengths array of shuffle block size so we can tell shuffle block
+ * @param dep shuffle dependency to get shuffle ID and the location of remote shuffle
+ * services to push local shuffle blocks
+ * @param mapIndex map index of the shuffle map task
+ */
+ private[shuffle] def initiateBlockPush(
+ dataFile: File,
+ partitionLengths: Array[Long],
+ dep: ShuffleDependency[_, _, _],
+ mapIndex: Int): Unit = {
+ val numPartitions = dep.partitioner.numPartitions
+ val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+ val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile,
+ partitionLengths, dep.getMergerLocs, transportConf)
+ // Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
+ // time won't be pushing the same ranges of shuffle partitions.
+ pushRequests ++= Utils.randomize(requests)
+
+ submitTask(() => {
+ pushUpToMax()
+ })
+ }
+
+ /**
+ * Triggers the push. It's a separate method for testing.
+ * VisibleForTesting
+ */
+ protected def submitTask(task: Runnable): Unit = {
+ if (BLOCK_PUSHER_POOL != null) {
+ BLOCK_PUSHER_POOL.execute(task)
+ }
+ }
+
+ /**
+ * Since multiple block push threads could potentially be calling pushUpToMax for the same
+ * mapper, we synchronize access to this method so that only one thread can push blocks for
+ * a given mapper. This helps to simplify access to the shared states. The down side of this
+ * is that we could unnecessarily block other mappers' block pushes if all the threads
+ * are occupied by block pushes from the same mapper.
+ *
+ * This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in how it throttles
+ * the data transfer between shuffle client/server.
+ */
+ private def pushUpToMax(): Unit = synchronized {
+ // Process any outstanding deferred push requests if possible.
+ if (deferredPushRequests.nonEmpty) {
+ for ((remoteAddress, defReqQueue) <- deferredPushRequests) {
+ while (isRemoteBlockPushable(defReqQueue) &&
+ !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+ val request = defReqQueue.dequeue()
+ logDebug(s"Processing deferred push request for $remoteAddress with "
+ + s"${request.blocks.length} blocks")
+ sendRequest(request)
+ if (defReqQueue.isEmpty) {
+ deferredPushRequests -= remoteAddress
+ }
+ }
+ }
+ }
+
+ // Process any regular push requests if possible.
+ while (isRemoteBlockPushable(pushRequests)) {
+ val request = pushRequests.dequeue()
+ val remoteAddress = request.address
+ if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+ logDebug(s"Deferring push request for $remoteAddress with ${request.blocks.size} blocks")
+ deferredPushRequests.getOrElseUpdate(remoteAddress, new Queue[PushRequest]())
+ .enqueue(request)
+ } else {
+ sendRequest(request)
+ }
+ }
+
+ def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = {
+ pushReqQueue.nonEmpty &&
+ (bytesInFlight == 0 ||
+ (reqsInFlight + 1 <= maxReqsInFlight &&
+ bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight))
+ }
+
+ // Checks if sending a new push request will exceed the max no. of blocks being pushed to a
+ // given remote address.
+ def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: PushRequest): Boolean = {
+ (numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0)
+ + request.blocks.size) > maxBlocksInFlightPerAddress
+ }
+ }
+
+ /**
+ * Push blocks to remote shuffle server. The callback listener will invoke #pushUpToMax again
+ * to trigger pushing the next batch of blocks once some block transfer is done in the current
+ * batch. This way, we decouple the map task from the block push process, since it is netty
+ * client thread instead of task execution thread which takes care of majority of the block
+ * pushes.
+ */
+ private def sendRequest(request: PushRequest): Unit = {
+ bytesInFlight += request.size
+ reqsInFlight += 1
+ numBlocksInFlightPerAddress(request.address) = numBlocksInFlightPerAddress.getOrElseUpdate(
+ request.address, 0) + request.blocks.length
+
+ val sizeMap = request.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
+ val address = request.address
+ val blockIds = request.blocks.map(_._1.toString)
+ val remainingBlocks = new HashSet[String]() ++= blockIds
+
+ val blockPushListener = new BlockFetchingListener {
+ // Initiating a connection and pushing blocks to a remote shuffle service is always handled by
+ // the block-push-threads. We should not initiate the connection creation in the
+ // blockPushListener callbacks which are invoked by the netty eventloop because:
+ // 1. TransportClient.createConnection(...) blocks for connection to be established and it's
+ // recommended to avoid any blocking operations in the eventloop;
+ // 2. The actual connection creation is a task that gets added to the task queue of another
+ // eventloop which could have eventloops eventually blocking each other.
+ // Once the blockPushListener is notified of the block push success or failure, we
+ // just delegate it to block-push-threads.
+ def handleResult(result: PushResult): Unit = {
+ submitTask(() => {
+ if (updateStateAndCheckIfPushMore(
+ sizeMap(result.blockId), address, remainingBlocks, result)) {
+ pushUpToMax()
+ }
+ })
+ }
+
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ logTrace(s"Push for block $blockId to $address successful.")
+ handleResult(PushResult(blockId, null))
+ }
+
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ // Check the message or its cause to see if it needs to be logged.
+ if (!errorHandler.shouldLogError(exception)) {
+ logTrace(s"Pushing block $blockId to $address failed.", exception)
+ } else {
+ logWarning(s"Pushing block $blockId to $address failed.", exception)
+ }
+ handleResult(PushResult(blockId, exception))
+ }
+ }
+ SparkEnv.get.blockManager.blockStoreClient.pushBlocks(
+ address.host, address.port, blockIds.toArray,
+ sliceReqBufferIntoBlockBuffers(request.reqBuffer, request.blocks.map(_._2)),
+ blockPushListener)
+ }
+
+ /**
+ * Given the ManagedBuffer representing all the continuous blocks inside the shuffle data file
+ * for a PushRequest and an array of individual block sizes, load the buffer from disk into
+ * memory and slice it into multiple smaller buffers representing each block.
+ *
+ * With nio ByteBuffer, the individual block buffers share data with the initial in memory
+ * buffer loaded from disk. Thus only one copy of the block data is kept in memory.
+ * @param reqBuffer A {{FileSegmentManagedBuffer}} representing all the continuous blocks in
+ * the shuffle data file for a PushRequest
+ * @param blockSizes Array of block sizes
+ * @return Array of in memory buffer for each individual block
+ */
+ private def sliceReqBufferIntoBlockBuffers(
+ reqBuffer: ManagedBuffer,
+ blockSizes: Seq[Int]): Array[ManagedBuffer] = {
+ if (blockSizes.size == 1) {
+ Array(reqBuffer)
+ } else {
+ val inMemoryBuffer = reqBuffer.nioByteBuffer()
+ val blockOffsets = new Array[Int](blockSizes.size)
+ var offset = 0
+ for (index <- blockSizes.indices) {
+ blockOffsets(index) = offset
+ offset += blockSizes(index)
+ }
+ blockOffsets.zip(blockSizes).map {
+ case (offset, size) =>
+ new NioManagedBuffer(inMemoryBuffer.duplicate()
+ .position(offset)
+ .limit(offset + size).asInstanceOf[ByteBuffer].slice())
+ }.toArray
+ }
+ }
+
+ /**
+ * Updates the stats and based on the previous push result decides whether to push more blocks
+ * or stop.
+ *
+ * @param bytesPushed number of bytes pushed.
+ * @param address address of the remote service
+ * @param remainingBlocks remaining blocks
+ * @param pushResult result of the last push
+ * @return true if more blocks should be pushed; false otherwise.
+ */
+ private def updateStateAndCheckIfPushMore(
+ bytesPushed: Long,
+ address: BlockManagerId,
+ remainingBlocks: HashSet[String],
+ pushResult: PushResult): Boolean = synchronized {
+ remainingBlocks -= pushResult.blockId
+ bytesInFlight -= bytesPushed
+ numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+ if (remainingBlocks.isEmpty) {
+ reqsInFlight -= 1
+ }
+ if (pushResult.failure != null && pushResult.failure.getCause.isInstanceOf[ConnectException]) {
+ // Remove all the blocks for this address just once because removing from pushRequests
+ // is expensive. If there is a ConnectException for the first block, all the subsequent
+ // blocks to that address will fail, so should avoid removing multiple times.
+ if (!unreachableBlockMgrs.contains(address)) {
+ var removed = 0
+ unreachableBlockMgrs.add(address)
+ removed += pushRequests.dequeueAll(req => req.address == address).length
+ removed += deferredPushRequests.remove(address).map(_.length).getOrElse(0)
+ logWarning(s"Received a ConnectException from $address. " +
+ s"Dropping $removed push-requests and " +
+ s"not pushing any more blocks to this address.")
+ }
+ }
+ if (pushResult.failure != null && !errorHandler.shouldRetryError(pushResult.failure)) {
+ logDebug(s"Received after merge is finalized from $address. Not pushing any more blocks.")
+ return false
+ } else {
+ remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty)
+ }
+ }
+
+ /**
+ * Convert the shuffle data file of the current mapper into a list of PushRequest. Basically,
+ * continuous blocks in the shuffle file are grouped into a single request to allow more
+ * efficient read of the block data. Each mapper for a given shuffle will receive the same
+ * list of BlockManagerIds as the target location to push the blocks to. All mappers in the
+ * same shuffle will map shuffle partition ranges to individual target locations in a consistent
+ * manner to make sure each target location receives shuffle blocks belonging to the same set
+ * of partition ranges. 0-length blocks and blocks that are large enough will be skipped.
+ *
+ * @param numPartitions number of shuffle partitions in the shuffle file
+ * @param partitionId map index of the current mapper
+ * @param shuffleId shuffleId of current shuffle
+ * @param dataFile shuffle data file
+ * @param partitionLengths array of sizes of blocks in the shuffle data file
+ * @param mergerLocs target locations to push blocks to
+ * @param transportConf transportConf used to create FileSegmentManagedBuffer
+ * @return List of the PushRequest, randomly shuffled.
+ *
+ * VisibleForTesting
+ */
+ private[shuffle] def prepareBlockPushRequests(
+ numPartitions: Int,
+ partitionId: Int,
+ shuffleId: Int,
+ dataFile: File,
+ partitionLengths: Array[Long],
+ mergerLocs: Seq[BlockManagerId],
+ transportConf: TransportConf): Seq[PushRequest] = {
+ var offset = 0L
+ var currentReqSize = 0
+ var currentReqOffset = 0L
+ var currentMergerId = 0
+ val numMergers = mergerLocs.length
+ val requests = new ArrayBuffer[PushRequest]
+ var blocks = new ArrayBuffer[(BlockId, Int)]
+ for (reduceId <- 0 until numPartitions) {
+ val blockSize = partitionLengths(reduceId)
+ logDebug(
+ s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize")
+ // Skip 0-length blocks and blocks that are large enough
+ if (blockSize > 0) {
+ val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
+ numMergers - 1).asInstanceOf[Int]
+ // Start a new PushRequest if the current request goes beyond the max batch size,
+ // or the number of blocks in the current request goes beyond the limit per destination,
+ // or the next block push location is for a different shuffle service, or the next block
+ // exceeds the max block size to push limit. This guarantees that each PushRequest
+ // represents continuous blocks in the shuffle file to be pushed to the same shuffle
+ // service, and does not go beyond existing limitations.
+ if (currentReqSize + blockSize <= maxBlockBatchSize
+ && blocks.size < maxBlocksInFlightPerAddress
+ && mergerId == currentMergerId && blockSize <= maxBlockSizeToPush) {
+ // Add current block to current batch
+ currentReqSize += blockSize.toInt
+ } else {
+ if (blocks.nonEmpty) {
+ // Convert the previous batch into a PushRequest
+ requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
+ createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
+ blocks = new ArrayBuffer[(BlockId, Int)]
+ }
+ // Start a new batch
+ currentReqSize = 0
+ // Set currentReqOffset to -1 so we are able to distinguish between the initial value
+ // of currentReqOffset and when we are about to start a new batch
+ currentReqOffset = -1
+ currentMergerId = mergerId
+ }
+ // Only push blocks under the size limit
+ if (blockSize <= maxBlockSizeToPush) {
+ val blockSizeInt = blockSize.toInt
+ blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt))
+ // Only update currentReqOffset if the current block is the first in the request
+ if (currentReqOffset == -1) {
+ currentReqOffset = offset
+ }
+ if (currentReqSize == 0) {
+ currentReqSize += blockSizeInt
+ }
+ }
+ }
+ offset += blockSize
+ }
+ // Add in the final request
+ if (blocks.nonEmpty) {
+ requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
+ createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
+ }
+ requests.toSeq
+ }
+
+ // Visible for testing
+ protected def createRequestBuffer(
+ conf: TransportConf,
+ dataFile: File,
+ offset: Long,
+ length: Long): ManagedBuffer = {
+ new FileSegmentManagedBuffer(conf, dataFile, offset, length)
+ }
+}
+
+private[spark] object ShuffleBlockPusher {
+
+ /**
+ * A request to push blocks to a remote shuffle service
+ * @param address remote shuffle service location to push blocks to
+ * @param blocks list of block IDs and their sizes
+ * @param reqBuffer a chunk of data in the shuffle data file corresponding to the continuous
+ * blocks represented in this request
+ */
+ private[spark] case class PushRequest(
+ address: BlockManagerId,
+ blocks: Seq[(BlockId, Int)],
+ reqBuffer: ManagedBuffer) {
+ val size = blocks.map(_._2).sum
+ }
+
+ /**
+ * Result of the block push.
+ * @param blockId blockId
+ * @param failure exception if the push was unsuccessful; null otherwise;
+ */
+ private case class PushResult(blockId: String, failure: Throwable)
+
+ private val BLOCK_PUSHER_POOL: ExecutorService = {
+ val conf = SparkEnv.get.conf
+ if (Utils.isPushBasedShuffleEnabled(conf)) {
+ val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
+ .getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
+ ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Stop the shuffle pusher pool if it isn't null.
+ */
+ private[spark] def stop(): Unit = {
+ if (BLOCK_PUSHER_POOL != null) {
+ BLOCK_PUSHER_POOL.shutdown()
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
index 1429144..abff650 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
@@ -21,6 +21,7 @@ import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.util.Utils
/**
* The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor
@@ -57,7 +58,23 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
createMetricsReporter(context))
writer.write(
rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
- writer.stop(success = true).get
+ val mapStatus = writer.stop(success = true)
+ if (mapStatus.isDefined) {
+ // Initiate shuffle push process if push based shuffle is enabled
+ // The map task only takes care of converting the shuffle data file into multiple
+ // block push requests. It delegates pushing the blocks to a different thread-pool -
+ // ShuffleBlockPusher.BLOCK_PUSHER_POOL.
+ if (Utils.isPushBasedShuffleEnabled(SparkEnv.get.conf) && dep.getMergerLocs.nonEmpty) {
+ manager.shuffleBlockResolver match {
+ case resolver: IndexShuffleBlockResolver =>
+ val dataFile = resolver.getDataFile(dep.shuffleId, mapId)
+ new ShuffleBlockPusher(SparkEnv.get.conf)
+ .initiateBlockPush(dataFile, writer.getPartitionLengths(), dep, partition.index)
+ case _ =>
+ }
+ }
+ }
+ mapStatus.get
} catch {
case e: Exception =>
try {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index 4cc4ef5..a279b4c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -31,4 +31,7 @@ private[spark] abstract class ShuffleWriter[K, V] {
/** Close this writer, passing along whether the map completed */
def stop(success: Boolean): Option[MapStatus]
+
+ /** Get the lengths of each partition */
+ def getPartitionLengths(): Array[Long]
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 83ebe3e..af8d1e2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -45,6 +45,8 @@ private[spark] class SortShuffleWriter[K, V, C](
private var mapStatus: MapStatus = null
+ private var partitionLengths: Array[Long] = _
+
private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
/** Write a bunch of records to this task's output */
@@ -67,7 +69,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
dep.shuffleId, mapId, dep.partitioner.numPartitions)
sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
- val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+ partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
}
@@ -93,6 +95,8 @@ private[spark] class SortShuffleWriter[K, V, C](
}
}
}
+
+ override def getPartitionLengths(): Array[Long] = partitionLengths
}
private[spark] object SortShuffleWriter {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7b084e7..73bf809 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -20,7 +20,7 @@ package org.apache.spark.storage
import java.util.UUID
import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
/**
* :: DeveloperApi ::
@@ -81,6 +81,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
}
+@Since("3.2.0")
+@DeveloperApi
+case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
+ override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId
+}
+
@DeveloperApi
case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
@@ -122,6 +128,7 @@ object BlockId {
val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
+ val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -140,6 +147,8 @@ object BlockId {
ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
+ case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) =>
+ ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt)
case BROADCAST(broadcastId, field) =>
BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
case TASKRESULT(taskId) =>
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
new file mode 100644
index 0000000..cc561e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -0,0 +1,355 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest
+import org.apache.spark.storage._
+
+class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: BlockStoreClient = _
+
+ private var conf: SparkConf = _
+ private var pushedBlocks = new ArrayBuffer[String]
+
+ override def beforeEach(): Unit = {
+ super.beforeEach()
+ conf = new SparkConf(loadDefaults = false)
+ MockitoAnnotations.initMocks(this)
+ when(dependency.partitioner).thenReturn(new HashPartitioner(8))
+ when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+ when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1)))
+ conf.set("spark.shuffle.push.based.enabled", "true")
+ conf.set("spark.shuffle.service.enabled", "true")
+ // Set the env because the shuffle writer gets the shuffle client instance from the env.
+ val mockEnv = mock(classOf[SparkEnv])
+ when(mockEnv.conf).thenReturn(conf)
+ when(mockEnv.blockManager).thenReturn(blockManager)
+ SparkEnv.set(mockEnv)
+ when(blockManager.blockStoreClient).thenReturn(shuffleClient)
+ }
+
+ override def afterEach(): Unit = {
+ pushedBlocks.clear()
+ super.afterEach()
+ }
+
+ private def interceptPushedBlocksForSuccess(): Unit = {
+ when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ pushedBlocks ++= blocks
+ val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+ val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+ blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+ })
+ })
+ }
+
+ private def verifyPushRequests(
+ pushRequests: Seq[PushRequest],
+ expectedSizes: Seq[Int]): Unit = {
+ (pushRequests, expectedSizes).zipped.foreach((req, size) => {
+ assert(req.size == size)
+ })
+ }
+
+ test("A batch of blocks is limited by maxBlocksBatchSize") {
+ conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+ conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k")
+ val blockPusher = new TestShuffleBlockPusher(conf)
+ val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+ val largeBlockSize = 2 * 1024 * 1024
+ val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+ mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
+ mock(classOf[TransportConf]))
+ assert(pushRequests.length == 3)
+ verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize))
+ }
+
+ test("Large blocks are excluded in the preparation") {
+ conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+ val blockPusher = new TestShuffleBlockPusher(conf)
+ val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+ val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+ mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
+ assert(pushRequests.length == 2)
+ verifyPushRequests(pushRequests, Seq(6, 1024))
+ }
+
+ test("Number of blocks in a push request are limited by maxBlocksInFlightPerAddress ") {
+ conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+ val blockPusher = new TestShuffleBlockPusher(conf)
+ val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+ val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+ mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
+ assert(pushRequests.length == 5)
+ verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
+ }
+
+ test("Basic block push") {
+ interceptPushedBlocksForSuccess()
+ val blockPusher = new TestShuffleBlockPusher(conf)
+ blockPusher.initiateBlockPush(mock(classOf[File]),
+ Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+ blockPusher.runPendingTasks()
+ verify(shuffleClient, times(1))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+ ShuffleBlockPusher.stop()
+ }
+
+ test("Large blocks are skipped for push") {
+ conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+ interceptPushedBlocksForSuccess()
+ val pusher = new TestShuffleBlockPusher(conf)
+ pusher.initiateBlockPush(
+ mock(classOf[File]), Array(2, 2, 2, 2, 2, 2, 2, 1100), dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(1))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+ ShuffleBlockPusher.stop()
+ }
+
+ test("Number of blocks in flight per address are limited by maxBlocksInFlightPerAddress") {
+ conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+ interceptPushedBlocksForSuccess()
+ val pusher = new TestShuffleBlockPusher(conf)
+ pusher.initiateBlockPush(
+ mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(8))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+ ShuffleBlockPusher.stop()
+ }
+
+ test("Hit maxBlocksInFlightPerAddress limit so that the blocks are deferred") {
+ conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+ var blockPendingResponse : String = null
+ var listener : BlockFetchingListener = null
+ when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ pushedBlocks ++= blocks
+ val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+ val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ // Expecting 2 blocks
+ assert(blocks.length == 2)
+ if (blockPendingResponse == null) {
+ blockPendingResponse = blocks(1)
+ listener = blockFetchListener
+ // Respond with success only for the first block which will cause all the rest of the
+ // blocks to be deferred
+ blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0))
+ } else {
+ (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+ blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+ })
+ }
+ })
+ val pusher = new TestShuffleBlockPusher(conf)
+ pusher.initiateBlockPush(
+ mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(1))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == 2)
+ // this will trigger push of deferred blocks
+ listener.onBlockFetchSuccess(blockPendingResponse, mock(classOf[ManagedBuffer]))
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(4))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == 8)
+ ShuffleBlockPusher.stop()
+ }
+
+ test("Number of shuffle blocks grouped in a single push request is limited by " +
+ "maxBlockBatchSize") {
+ conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+ interceptPushedBlocksForSuccess()
+ val pusher = new TestShuffleBlockPusher(conf)
+ pusher.initiateBlockPush(mock(classOf[File]),
+ Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(4))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+ ShuffleBlockPusher.stop()
+ }
+
+ test("Error retries") {
+ val pusher = new ShuffleBlockPusher(conf)
+ val errorHandler = pusher.createErrorHandler()
+ assert(
+ !errorHandler.shouldRetryError(new RuntimeException(
+ new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+ assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException())))
+ assert(
+ errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+ BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+ assert (errorHandler.shouldRetryError(new Throwable()))
+ }
+
+ test("Error logging") {
+ val pusher = new ShuffleBlockPusher(conf)
+ val errorHandler = pusher.createErrorHandler()
+ assert(
+ !errorHandler.shouldLogError(new RuntimeException(
+ new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+ assert(!errorHandler.shouldLogError(new RuntimeException(
+ new IllegalArgumentException(
+ BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+ assert(errorHandler.shouldLogError(new Throwable()))
+ }
+
+ test("Blocks are continued to push even when a block push fails with collision " +
+ "exception") {
+ conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+ val pusher = new TestShuffleBlockPusher(conf)
+ var failBlock: Boolean = true
+ when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ blocks.foreach(blockId => {
+ if (failBlock) {
+ failBlock = false
+ // Fail the first block with the collision exception.
+ blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException(
+ new IllegalArgumentException(
+ BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))
+ } else {
+ pushedBlocks += blockId
+ blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer]))
+ }
+ })
+ })
+ pusher.initiateBlockPush(
+ mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(8))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.length == 7)
+ }
+
+ test("More blocks are not pushed when a block push fails with too late " +
+ "exception") {
+ conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+ val pusher = new TestShuffleBlockPusher(conf)
+ var failBlock: Boolean = true
+ when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ blocks.foreach(blockId => {
+ if (failBlock) {
+ failBlock = false
+ // Fail the first block with the too late exception.
+ blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException(
+ new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))
+ } else {
+ pushedBlocks += blockId
+ blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer]))
+ }
+ })
+ })
+ pusher.initiateBlockPush(
+ mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(1))
+ .pushBlocks(any(), any(), any(), any(), any())
+ assert(pushedBlocks.isEmpty)
+ }
+
+ test("Connect exceptions remove all the push requests for that host") {
+ when(dependency.getMergerLocs).thenReturn(
+ Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", "client2", 2)))
+ conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+ when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+ .thenAnswer((invocation: InvocationOnMock) => {
+ val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+ pushedBlocks ++= blocks
+ val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ blocks.foreach(blockId => {
+ blockFetchListener.onBlockFetchFailure(
+ blockId, new RuntimeException(new ConnectException()))
+ })
+ })
+ val pusher = new TestShuffleBlockPusher(conf)
+ pusher.initiateBlockPush(
+ mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+ pusher.runPendingTasks()
+ verify(shuffleClient, times(2))
+ .pushBlocks(any(), any(), any(), any(), any())
+ // 2 blocks for each merger locations
+ assert(pushedBlocks.length == 4)
+ assert(pusher.unreachableBlockMgrs.size == 2)
+ }
+
+ private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) {
+ private[this] val tasks = new LinkedBlockingQueue[Runnable]
+
+ override protected def submitTask(task: Runnable): Unit = {
+ tasks.add(task)
+ }
+
+ def runPendingTasks(): Unit = {
+ // This ensures that all the submitted tasks - updateStateAndCheckIfPushMore and pushUpToMax
+ // are run synchronously.
+ while (!tasks.isEmpty) {
+ tasks.take().run()
+ }
+ }
+
+ override protected def createRequestBuffer(
+ conf: TransportConf,
+ dataFile: File,
+ offset: Long,
+ length: Long): ManagedBuffer = {
+ val managedBuffer = mock(classOf[ManagedBuffer])
+ val byteBuffer = new Array[Byte](length.toInt)
+ when(managedBuffer.nioByteBuffer()).thenReturn(ByteBuffer.wrap(byteBuffer))
+ managedBuffer
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org