You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by mr...@apache.org on 2021/01/08 18:22:37 UTC

[spark] branch master updated: [SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d00f069  [SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion
d00f069 is described below

commit d00f0695b7513046e42e47f35b280d7aa494de5b
Author: Chandni Singh <si...@gmail.com>
AuthorDate: Fri Jan 8 12:21:56 2021 -0600

    [SPARK-32917][SHUFFLE][CORE] Adds support for executors to push shuffle blocks after successful map task completion
    
    ### What changes were proposed in this pull request?
    This is the shuffle writer side change where executors can push data to remote shuffle services. This is needed for push-based shuffle - SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
    Summary of changes:
    - This adds support for executors to push shuffle blocks after map tasks complete writing shuffle data.
    - This also introduces a timeout specifically for creating connections to remote shuffle services.
    
    ### Why are the changes needed?
    - These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
    - The main reason to create a separate connection creation timeout is that the existing `connectionTimeoutMs` is overloaded: it is used both as the connection creation timeout and as the connection idle timeout. The connection creation timeout should be much lower than the idle timeout. The default for `connectionTimeoutMs` is 120s, which is quite high for just establishing a connection. If a shuffle server node is bad, connection creation will fail within a few seconds. However [...]
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. This PR introduces client-side configs for push-based shuffle. If push-based shuffle is turned off, users will not see any change.
    
    ### How was this patch tested?
    Added unit tests.
    The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
    We have already verified the functionality and the improved performance as documented in the SPIP doc.
    
    Lead-authored-by: Min Shen mshenlinkedin.com
    Co-authored-by: Chandni Singh chsinghlinkedin.com
    Co-authored-by: Ye Zhou yezhoulinkedin.com
    
    Closes #30312 from otterc/SPARK-32917.
    
    Lead-authored-by: Chandni Singh <si...@gmail.com>
    Co-authored-by: Chandni Singh <ch...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linked.in.com>
    Co-authored-by: Ye Zhou <ye...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../network/client/TransportClientFactory.java     |   7 +-
 .../apache/spark/network/util/TransportConf.java   |  13 +-
 .../shuffle/sort/BypassMergeSortShuffleWriter.java |   5 +-
 .../spark/shuffle/sort/UnsafeShuffleWriter.java    |   7 +-
 .../scala/org/apache/spark/executor/Executor.scala |   3 +-
 .../org/apache/spark/internal/config/package.scala |  29 ++
 .../apache/spark/shuffle/ShuffleBlockPusher.scala  | 450 +++++++++++++++++++++
 .../spark/shuffle/ShuffleWriteProcessor.scala      |  19 +-
 .../org/apache/spark/shuffle/ShuffleWriter.scala   |   3 +
 .../spark/shuffle/sort/SortShuffleWriter.scala     |   6 +-
 .../scala/org/apache/spark/storage/BlockId.scala   |  11 +-
 .../spark/shuffle/ShuffleBlockPusherSuite.scala    | 355 ++++++++++++++++
 12 files changed, 896 insertions(+), 12 deletions(-)

diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
index 24c436a..43408d4 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -254,7 +254,7 @@ public class TransportClientFactory implements Closeable {
       // Disable Nagle's Algorithm since we don't want packets to wait
       .option(ChannelOption.TCP_NODELAY, true)
       .option(ChannelOption.SO_KEEPALIVE, true)
-      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
+      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionCreationTimeoutMs())
       .option(ChannelOption.ALLOCATOR, pooledAllocator);
 
     if (conf.receiveBuf() > 0) {
@@ -280,9 +280,10 @@ public class TransportClientFactory implements Closeable {
     // Connect to the remote server
     long preConnect = System.nanoTime();
     ChannelFuture cf = bootstrap.connect(address);
-    if (!cf.await(conf.connectionTimeoutMs())) {
+    if (!cf.await(conf.connectionCreationTimeoutMs())) {
       throw new IOException(
-        String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+        String.format("Connecting to %s timed out (%s ms)",
+          address, conf.connectionCreationTimeoutMs()));
     } else if (cf.cause() != null) {
       throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
index d305dfa..f051042 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java
@@ -19,6 +19,7 @@ package org.apache.spark.network.util;
 
 import java.util.Locale;
 import java.util.Properties;
+import java.util.concurrent.TimeUnit;
 
 import com.google.common.primitives.Ints;
 import io.netty.util.NettyRuntime;
@@ -31,6 +32,7 @@ public class TransportConf {
   private final String SPARK_NETWORK_IO_MODE_KEY;
   private final String SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY;
   private final String SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY;
+  private final String SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY;
   private final String SPARK_NETWORK_IO_BACKLOG_KEY;
   private final String SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY;
   private final String SPARK_NETWORK_IO_SERVERTHREADS_KEY;
@@ -54,6 +56,7 @@ public class TransportConf {
     SPARK_NETWORK_IO_MODE_KEY = getConfKey("io.mode");
     SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY = getConfKey("io.preferDirectBufs");
     SPARK_NETWORK_IO_CONNECTIONTIMEOUT_KEY = getConfKey("io.connectionTimeout");
+    SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY = getConfKey("io.connectionCreationTimeout");
     SPARK_NETWORK_IO_BACKLOG_KEY = getConfKey("io.backLog");
     SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY =  getConfKey("io.numConnectionsPerPeer");
     SPARK_NETWORK_IO_SERVERTHREADS_KEY = getConfKey("io.serverThreads");
@@ -94,7 +97,7 @@ public class TransportConf {
     return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
   }
 
-  /** Connect timeout in milliseconds. Default 120 secs. */
+  /** Connection idle timeout in milliseconds. Default 120 secs. */
   public int connectionTimeoutMs() {
     long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
       conf.get("spark.network.timeout", "120s"));
@@ -103,6 +106,14 @@ public class TransportConf {
     return (int) defaultTimeoutMs;
   }
 
+  /** Connection creation timeout in ms (second precision); defaults to connectionTimeoutMs(). */
+  public int connectionCreationTimeoutMs() {
+    long connectionTimeoutS = TimeUnit.MILLISECONDS.toSeconds(connectionTimeoutMs());
+    long defaultTimeoutMs = JavaUtils.timeStringAsSec(
+      conf.get(SPARK_NETWORK_IO_CONNECTIONCREATIONTIMEOUT_KEY,  connectionTimeoutS + "s")) * 1000;
+    return (int) defaultTimeoutMs;
+  }
+
   /** Number of concurrent connections between two nodes for fetching data. */
   public int numConnectionsPerPeer() {
     return conf.getInt(SPARK_NETWORK_IO_NUMCONNECTIONSPERPEER_KEY, 1);
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 256789b..3dbee1b 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -31,7 +31,6 @@ import scala.Product2;
 import scala.Tuple2;
 import scala.collection.Iterator;
 
-import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.Closeables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -178,8 +177,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     }
   }
 
-  @VisibleForTesting
-  long[] getPartitionLengths() {
+  @Override
+  public long[] getPartitionLengths() {  // now a ShuffleWriter override (previously test-only)
     return partitionLengths;
   }
 
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 79e38a8..e8f94ba 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -88,6 +88,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   @Nullable private MapStatus mapStatus;
   @Nullable private ShuffleExternalSorter sorter;
+  @Nullable private long[] partitionLengths;
   private long peakMemoryUsedBytes = 0;
 
   /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
@@ -219,7 +220,6 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
-    final long[] partitionLengths;
     try {
       partitionLengths = mergeSpills(spills);
     } finally {
@@ -543,4 +543,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       channel.close();
     }
   }
+
+  @Override
+  public long[] getPartitionLengths() {
+    return partitionLengths;  // may be null: the field is @Nullable until it is assigned
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c58009c..3865c9c 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -47,7 +47,7 @@ import org.apache.spark.metrics.source.JVMCPUSource
 import org.apache.spark.resource.ResourceInformation
 import org.apache.spark.rpc.RpcTimeout
 import org.apache.spark.scheduler._
-import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.{FetchFailedException, ShuffleBlockPusher}
 import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
 import org.apache.spark.util._
 import org.apache.spark.util.io.ChunkedByteBuffer
@@ -325,6 +325,7 @@ private[spark] class Executor(
         case NonFatal(e) =>
           logWarning("Unable to stop heartbeater", e)
       }
+      ShuffleBlockPusher.stop()
       threadPool.shutdown()
 
       // Notify plugins that executor is shutting down so they can terminate cleanly
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index adaf92d..84c6647 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2030,4 +2030,33 @@ package object config {
       .version("3.1.0")
       .doubleConf
       .createWithDefault(5)
+
+  private[spark] val SHUFFLE_NUM_PUSH_THREADS =
+    ConfigBuilder("spark.shuffle.push.numPushThreads")
+      .doc("Specify the number of threads in the block pusher pool. These threads assist " +
+        "in creating connections and pushing blocks to remote shuffle services. By default, the " +
+        "threadpool size is equal to the number of spark executor cores.")
+      .version("3.2.0")
+      .intConf
+      .createOptional // when unset, ShuffleBlockPusher falls back to spark.executor.cores
+
+  private[spark] val SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH =
+    ConfigBuilder("spark.shuffle.push.maxBlockSizeToPush")
+      .doc("The max size of an individual block to push to the remote shuffle services. Blocks " +
+       "larger than this threshold are not pushed to be merged remotely. These shuffle blocks " +
+       "will be fetched by the executors in the original manner.")
+      .version("3.2.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("1m") // larger blocks are skipped by ShuffleBlockPusher
+
+  private[spark] val SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH =
+    ConfigBuilder("spark.shuffle.push.maxBlockBatchSize")
+      .doc("The max size of a batch of shuffle blocks to be grouped into a single push request.")
+      .version("3.2.0")
+      .bytesConf(ByteUnit.BYTE)
+      // Default is 3m because it is greater than 2m which is the default value for
+      // TransportConf#memoryMapBytes. If this defaults to 2m as well it is very likely that each
+      // batch of blocks will be loaded in memory with memory mapping, which has higher overhead
+      // with small MB-sized chunks of data.
+      .createWithDefaultString("3m")
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
new file mode 100644
index 0000000..88d084c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+
+import com.google.common.base.Throwables
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleBlockPusher._
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled.
+ * When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle
+ * file and initiates the block push process.
+ *
+ * @param conf Spark configuration; supplies the push size limits and in-flight throttles
+ */
+@Since("3.2.0")
+private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
+  private[this] val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH)
+  private[this] val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH)
+  private[this] val maxBytesInFlight =
+    conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024
+  private[this] val maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue)
+  private[this] val maxBlocksInFlightPerAddress = conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
+  private[this] var bytesInFlight = 0L
+  private[this] var reqsInFlight = 0
+  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]()
+  private[this] val deferredPushRequests = new HashMap[BlockManagerId, Queue[PushRequest]]()
+  private[this] val pushRequests = new Queue[PushRequest]
+  private[this] val errorHandler = createErrorHandler()
+  // VisibleForTesting
+  private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+
+  // VisibleForTesting
+  private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
+    new BlockPushErrorHandler() {
+      // For a connection exception against a particular host, we will stop pushing any
+      // blocks to just that host and continue pushing blocks to other hosts. So, here push of
+      // all blocks will only stop when it is "Too Late". Also see updateStateAndCheckIfPushMore.
+      override def shouldRetryError(t: Throwable): Boolean = {
+        // If the block is too late, there is no need to retry it
+        !Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
+      }
+    }
+  }
+
+  /**
+   * Initiates the block push.
+   *
+   * @param dataFile         mapper generated shuffle data file
+   * @param partitionLengths array of shuffle block sizes, indexed by reduce partition id
+   * @param dep              shuffle dependency to get shuffle ID and the location of remote shuffle
+   *                         services to push local shuffle blocks
+   * @param mapIndex      map index of the shuffle map task
+   */
+  private[shuffle] def initiateBlockPush(
+      dataFile: File,
+      partitionLengths: Array[Long],
+      dep: ShuffleDependency[_, _, _],
+      mapIndex: Int): Unit = {
+    val numPartitions = dep.partitioner.numPartitions
+    val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+    val requests = prepareBlockPushRequests(numPartitions, mapIndex, dep.shuffleId, dataFile,
+      partitionLengths, dep.getMergerLocs, transportConf)
+    // Randomize the orders of the PushRequest, so different mappers pushing blocks at the same
+    // time won't be pushing the same ranges of shuffle partitions.
+    pushRequests ++= Utils.randomize(requests)
+
+    submitTask(() => {
+      pushUpToMax()
+    })
+  }
+
+  /**
+   * Runs the task on BLOCK_PUSHER_POOL; a no-op if the pool is null. Separate for testing.
+   * VisibleForTesting
+   */
+  protected def submitTask(task: Runnable): Unit = {
+    if (BLOCK_PUSHER_POOL != null) {
+      BLOCK_PUSHER_POOL.execute(task)
+    }
+  }
+
+  /**
+   * Since multiple block push threads could potentially be calling pushUpToMax for the same
+   * mapper, we synchronize access to this method so that only one thread can push blocks for
+   * a given mapper. This helps to simplify access to the shared states. The downside of this
+   * is that we could unnecessarily block other mappers' block pushes if all the threads
+   * are occupied by block pushes from the same mapper.
+   *
+   * This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in how it throttles
+   * the data transfer between shuffle client/server.
+   */
+  private def pushUpToMax(): Unit = synchronized {
+    // Process any outstanding deferred push requests if possible.
+    if (deferredPushRequests.nonEmpty) {
+      for ((remoteAddress, defReqQueue) <- deferredPushRequests) {
+        while (isRemoteBlockPushable(defReqQueue) &&
+          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+          val request = defReqQueue.dequeue()
+          logDebug(s"Processing deferred push request for $remoteAddress with "
+            + s"${request.blocks.length} blocks")
+          sendRequest(request)
+          if (defReqQueue.isEmpty) {
+            deferredPushRequests -= remoteAddress // NOTE(review): mutates map under iteration
+          }
+        }
+      }
+    }
+
+    // Process any regular push requests if possible.
+    while (isRemoteBlockPushable(pushRequests)) {
+      val request = pushRequests.dequeue()
+      val remoteAddress = request.address
+      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+        logDebug(s"Deferring push request for $remoteAddress with ${request.blocks.size} blocks")
+        deferredPushRequests.getOrElseUpdate(remoteAddress, new Queue[PushRequest]())
+          .enqueue(request)
+      } else {
+        sendRequest(request)
+      }
+    }
+
+    def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = {
+      pushReqQueue.nonEmpty &&
+        (bytesInFlight == 0 ||
+          (reqsInFlight + 1 <= maxReqsInFlight &&
+            bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight))
+    }
+
+    // Checks if sending a new push request will exceed the max no. of blocks being pushed to a
+    // given remote address.
+    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: PushRequest): Boolean = {
+      (numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0)
+        + request.blocks.size) > maxBlocksInFlightPerAddress
+    }
+  }
+
+  /**
+   * Push blocks to remote shuffle server. The callback listener will invoke #pushUpToMax again
+   * to trigger pushing the next batch of blocks once some block transfer is done in the current
+   * batch. This way, we decouple the map task from the block push process, since it is netty
+   * client thread instead of task execution thread which takes care of majority of the block
+   * pushes.
+   */
+  private def sendRequest(request: PushRequest): Unit = {
+    bytesInFlight +=  request.size
+    reqsInFlight += 1
+    numBlocksInFlightPerAddress(request.address) = numBlocksInFlightPerAddress.getOrElseUpdate(
+      request.address, 0) + request.blocks.length
+
+    val sizeMap = request.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
+    val address = request.address
+    val blockIds = request.blocks.map(_._1.toString)
+    val remainingBlocks = new HashSet[String]() ++= blockIds
+
+    val blockPushListener = new BlockFetchingListener {
+      // Initiating a connection and pushing blocks to a remote shuffle service is always handled by
+      // the block-push-threads. We should not initiate the connection creation in the
+      // blockPushListener callbacks which are invoked by the netty eventloop because:
+      // 1. Creating a TransportClient blocks until the connection is established and it's
+      // recommended to avoid any blocking operations in the eventloop;
+      // 2. The actual connection creation is a task that gets added to the task queue of another
+      // eventloop which could have eventloops eventually blocking each other.
+      // Once the blockPushListener is notified of the block push success or failure, we
+      // just delegate it to block-push-threads.
+      def handleResult(result: PushResult): Unit = {
+        submitTask(() => {
+          if (updateStateAndCheckIfPushMore(
+            sizeMap(result.blockId), address, remainingBlocks, result)) {
+            pushUpToMax()
+          }
+        })
+      }
+
+      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+        logTrace(s"Push for block $blockId to $address successful.")
+        handleResult(PushResult(blockId, null))
+      }
+
+      override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+        // Check the exception message or its cause to decide whether it needs to be logged.
+        if (!errorHandler.shouldLogError(exception)) {
+          logTrace(s"Pushing block $blockId to $address failed.", exception)
+        } else {
+          logWarning(s"Pushing block $blockId to $address failed.", exception)
+        }
+        handleResult(PushResult(blockId, exception))
+      }
+    }
+    SparkEnv.get.blockManager.blockStoreClient.pushBlocks(
+      address.host, address.port, blockIds.toArray,
+      sliceReqBufferIntoBlockBuffers(request.reqBuffer, request.blocks.map(_._2)),
+      blockPushListener)
+  }
+
+  /**
+   * Given the ManagedBuffer representing all the continuous blocks inside the shuffle data file
+   * for a PushRequest and an array of individual block sizes, load the buffer from disk into
+   * memory and slice it into multiple smaller buffers representing each block.
+   *
+   * With nio ByteBuffer, the individual block buffers share data with the initial in memory
+   * buffer loaded from disk. Thus only one copy of the block data is kept in memory.
+   * @param reqBuffer A `FileSegmentManagedBuffer` representing all the continuous blocks in
+   *                  the shuffle data file for a PushRequest
+   * @param blockSizes Array of block sizes
+   * @return One buffer per block; `reqBuffer` itself is returned if there is a single block
+   */
+  private def sliceReqBufferIntoBlockBuffers(
+      reqBuffer: ManagedBuffer,
+      blockSizes: Seq[Int]): Array[ManagedBuffer] = {
+    if (blockSizes.size == 1) {
+      Array(reqBuffer)
+    } else {
+      val inMemoryBuffer = reqBuffer.nioByteBuffer()
+      val blockOffsets = new Array[Int](blockSizes.size)
+      var offset = 0
+      for (index <- blockSizes.indices) {
+        blockOffsets(index) = offset
+        offset += blockSizes(index)
+      }
+      blockOffsets.zip(blockSizes).map {
+        case (offset, size) =>
+          new NioManagedBuffer(inMemoryBuffer.duplicate()
+            .position(offset)
+            .limit(offset + size).asInstanceOf[ByteBuffer].slice())
+      }.toArray
+    }
+  }
+
+  /**
+   * Updates the stats and based on the previous push result decides whether to push more blocks
+   * or stop.
+   *
+   * @param bytesPushed     number of bytes pushed.
+   * @param address         address of the remote service
+   * @param remainingBlocks remaining blocks
+   * @param pushResult      result of the last push
+   * @return true if more blocks should be pushed; false otherwise.
+   */
+  private def updateStateAndCheckIfPushMore(
+      bytesPushed: Long,
+      address: BlockManagerId,
+      remainingBlocks: HashSet[String],
+      pushResult: PushResult): Boolean = synchronized {
+    remainingBlocks -= pushResult.blockId
+    bytesInFlight -= bytesPushed
+    numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+    if (remainingBlocks.isEmpty) {
+      reqsInFlight -= 1
+    }
+    if (pushResult.failure != null && pushResult.failure.getCause.isInstanceOf[ConnectException]) {
+      // Remove all the blocks for this address just once because removing from pushRequests
+      // is expensive. If there is a ConnectException for the first block, all the subsequent
+      // blocks to that address will fail, so should avoid removing multiple times.
+      if (!unreachableBlockMgrs.contains(address)) {
+        var removed = 0
+        unreachableBlockMgrs.add(address)
+        removed += pushRequests.dequeueAll(req => req.address == address).length
+        removed += deferredPushRequests.remove(address).map(_.length).getOrElse(0)
+        logWarning(s"Received a ConnectException from $address. " +
+          s"Dropping $removed push-requests and " +
+          s"not pushing any more blocks to this address.")
+      }
+    }
+    if (pushResult.failure != null && !errorHandler.shouldRetryError(pushResult.failure)) {
+      logDebug(s"Received after merge is finalized from $address. Not pushing any more blocks.")
+      return false
+    } else {
+      remainingBlocks.isEmpty && (pushRequests.nonEmpty || deferredPushRequests.nonEmpty)
+    }
+  }
+
+  /**
+   * Convert the shuffle data file of the current mapper into a list of PushRequest. Basically,
+   * continuous blocks in the shuffle file are grouped into a single request to allow more
+   * efficient read of the block data. Each mapper for a given shuffle will receive the same
+   * list of BlockManagerIds as the target location to push the blocks to. All mappers in the
+   * same shuffle will map shuffle partition ranges to individual target locations in a consistent
+   * manner to make sure each target location receives shuffle blocks belonging to the same set
+   * of partition ranges. 0-length blocks and blocks larger than maxBlockSizeToPush are skipped.
+   *
+   * @param numPartitions number of shuffle partitions in the shuffle file
+   * @param partitionId map index of the current mapper
+   * @param shuffleId shuffleId of current shuffle
+   * @param dataFile shuffle data file
+   * @param partitionLengths array of sizes of blocks in the shuffle data file
+   * @param mergerLocs target locations to push blocks to
+   * @param transportConf transportConf used to create FileSegmentManagedBuffer
+   * @return List of PushRequests in shuffle-file order (initiateBlockPush randomizes them).
+   *
+   * VisibleForTesting
+   */
+  private[shuffle] def prepareBlockPushRequests(
+      numPartitions: Int,
+      partitionId: Int,
+      shuffleId: Int,
+      dataFile: File,
+      partitionLengths: Array[Long],
+      mergerLocs: Seq[BlockManagerId],
+      transportConf: TransportConf): Seq[PushRequest] = {
+    var offset = 0L
+    var currentReqSize = 0
+    var currentReqOffset = 0L
+    var currentMergerId = 0
+    val numMergers = mergerLocs.length
+    val requests = new ArrayBuffer[PushRequest]
+    var blocks = new ArrayBuffer[(BlockId, Int)]
+    for (reduceId <- 0 until numPartitions) {
+      val blockSize = partitionLengths(reduceId)
+      logDebug(
+        s"Block ${ShufflePushBlockId(shuffleId, partitionId, reduceId)} is of size $blockSize")
+      // Skip 0-length blocks here; blocks over maxBlockSizeToPush are skipped below
+      if (blockSize > 0) {
+        val mergerId = math.min(math.floor(reduceId * 1.0 / numPartitions * numMergers),
+          numMergers - 1).asInstanceOf[Int]
+        // Start a new PushRequest if the current request goes beyond the max batch size,
+        // or the number of blocks in the current request goes beyond the limit per destination,
+        // or the next block push location is for a different shuffle service, or the next block
+        // exceeds the max block size to push limit. This guarantees that each PushRequest
+        // represents continuous blocks in the shuffle file to be pushed to the same shuffle
+        // service, and does not go beyond existing limitations.
+        if (currentReqSize + blockSize <= maxBlockBatchSize
+          && blocks.size < maxBlocksInFlightPerAddress
+          && mergerId == currentMergerId && blockSize <= maxBlockSizeToPush) {
+          // Add current block to current batch
+          currentReqSize += blockSize.toInt
+        } else {
+          if (blocks.nonEmpty) {
+            // Convert the previous batch into a PushRequest
+            requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
+              createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
+            blocks = new ArrayBuffer[(BlockId, Int)]
+          }
+          // Start a new batch
+          currentReqSize = 0
+          // Set currentReqOffset to -1 so we are able to distinguish between the initial value
+          // of currentReqOffset and when we are about to start a new batch
+          currentReqOffset = -1
+          currentMergerId = mergerId
+        }
+        // Only push blocks under the size limit
+        if (blockSize <= maxBlockSizeToPush) {
+          val blockSizeInt = blockSize.toInt
+          blocks += ((ShufflePushBlockId(shuffleId, partitionId, reduceId), blockSizeInt))
+          // Only update currentReqOffset if the current block is the first in the request
+          if (currentReqOffset == -1) {
+            currentReqOffset = offset
+          }
+          if (currentReqSize == 0) {
+            currentReqSize += blockSizeInt
+          }
+        }
+      }
+      offset += blockSize
+    }
+    // Add in the final request
+    if (blocks.nonEmpty) {
+      requests += PushRequest(mergerLocs(currentMergerId), blocks.toSeq,
+        createRequestBuffer(transportConf, dataFile, currentReqOffset, currentReqSize))
+    }
+    requests.toSeq
+  }
+
+  // Visible for testing
+  protected def createRequestBuffer(
+      conf: TransportConf,
+      dataFile: File,
+      offset: Long,
+      length: Long): ManagedBuffer = {
+    new FileSegmentManagedBuffer(conf, dataFile, offset, length)
+  }
+}
+
+private[spark] object ShuffleBlockPusher {
+
+  /**
+   * A request to push blocks to a remote shuffle service
+   * @param address remote shuffle service location to push blocks to
+   * @param blocks list of block IDs and their sizes
+   * @param reqBuffer a chunk of data in the shuffle data file corresponding to the continuous
+   *                  blocks represented in this request
+   */
+  private[spark] case class PushRequest(
+    address: BlockManagerId,
+    blocks: Seq[(BlockId, Int)],
+    reqBuffer: ManagedBuffer) {
+    val size = blocks.map(_._2).sum
+  }
+
+  /**
+   * Result of the block push.
+   * @param blockId blockId
+   * @param failure exception if the push was unsuccessful; null otherwise.
+   */
+  private case class PushResult(blockId: String, failure: Throwable)
+
+  private val BLOCK_PUSHER_POOL: ExecutorService = { // null if push-based shuffle is disabled
+    val conf = SparkEnv.get.conf
+    if (Utils.isPushBasedShuffleEnabled(conf)) {
+      val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
+        .getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
+      ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")
+    } else {
+      null
+    }
+  }
+
+  /**
+   * Stop the shuffle pusher pool if it isn't null.
+   */
+  private[spark] def stop(): Unit = {
+    if (BLOCK_PUSHER_POOL != null) {
+      BLOCK_PUSHER_POOL.shutdown()
+    }
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
index 1429144..abff650 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriteProcessor.scala
@@ -21,6 +21,7 @@ import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.util.Utils
 
 /**
  * The interface for customizing shuffle write process. The driver create a ShuffleWriteProcessor
@@ -57,7 +58,23 @@ private[spark] class ShuffleWriteProcessor extends Serializable with Logging {
         createMetricsReporter(context))
       writer.write(
         rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
-      writer.stop(success = true).get
+      val mapStatus = writer.stop(success = true)
+      if (mapStatus.isDefined) {
+        // Initiate shuffle push process if push based shuffle is enabled
+        // The map task only takes care of converting the shuffle data file into multiple
+        // block push requests. It delegates pushing the blocks to a different thread-pool -
+        // ShuffleBlockPusher.BLOCK_PUSHER_POOL.
+        if (Utils.isPushBasedShuffleEnabled(SparkEnv.get.conf) && dep.getMergerLocs.nonEmpty) {
+          manager.shuffleBlockResolver match {
+            case resolver: IndexShuffleBlockResolver =>
+              val dataFile = resolver.getDataFile(dep.shuffleId, mapId)
+              new ShuffleBlockPusher(SparkEnv.get.conf)
+                .initiateBlockPush(dataFile, writer.getPartitionLengths(), dep, partition.index)
+            case _ =>
+          }
+        }
+      }
+      mapStatus.get
     } catch {
       case e: Exception =>
         try {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index 4cc4ef5..a279b4c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -31,4 +31,7 @@ private[spark] abstract class ShuffleWriter[K, V] {
 
   /** Close this writer, passing along whether the map completed */
   def stop(success: Boolean): Option[MapStatus]
+
+  /** Get the lengths of each partition; callers read this after a successful stop(). */
+  def getPartitionLengths(): Array[Long]
 }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 83ebe3e..af8d1e2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -45,6 +45,8 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var mapStatus: MapStatus = null
 
+  private var partitionLengths: Array[Long] = _ // null until the map output is committed
+
   private val writeMetrics = context.taskMetrics().shuffleWriteMetrics
 
   /** Write a bunch of records to this task's output */
@@ -67,7 +69,7 @@ private[spark] class SortShuffleWriter[K, V, C](
     val mapOutputWriter = shuffleExecutorComponents.createMapOutputWriter(
       dep.shuffleId, mapId, dep.partitioner.numPartitions)
     sorter.writePartitionedMapOutput(dep.shuffleId, mapId, mapOutputWriter)
-    val partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
+    partitionLengths = mapOutputWriter.commitAllPartitions().getPartitionLengths
     mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths, mapId)
   }
 
@@ -93,6 +95,8 @@ private[spark] class SortShuffleWriter[K, V, C](
       }
     }
   }
+
+  override def getPartitionLengths(): Array[Long] = partitionLengths // null before commit
 }
 
 private[spark] object SortShuffleWriter {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7b084e7..73bf809 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -20,7 +20,7 @@ package org.apache.spark.storage
 import java.util.UUID
 
 import org.apache.spark.SparkException
-import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.annotation.{DeveloperApi, Since}
 
 /**
  * :: DeveloperApi ::
@@ -81,6 +81,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
 }
 
+/** Id of a pushed shuffle block; `name` is shufflePush_<shuffleId>_<mapIndex>_<reduceId>. */
+@DeveloperApi @Since("3.2.0")
+case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {
+  override def name: String = s"shufflePush_${shuffleId}_${mapIndex}_$reduceId"
+}
+
 @DeveloperApi
 case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
   override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
@@ -122,6 +128,7 @@ object BlockId {
   val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
+  val SHUFFLE_PUSH = "shufflePush_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -140,6 +147,8 @@ object BlockId {
       ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
     case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
       ShuffleIndexBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt)
+    case SHUFFLE_PUSH(shuffleId, mapIndex, reduceId) =>
+      ShufflePushBlockId(shuffleId.toInt, mapIndex.toInt, reduceId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
new file mode 100644
index 0000000..cc561e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -0,0 +1,355 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.LinkedBlockingQueue
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest
+import org.apache.spark.storage._
+
+/** Tests for [[ShuffleBlockPusher]]: push request preparation, push flow control and errors. */
+class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: BlockStoreClient = _
+
+  private var conf: SparkConf = _
+  private val pushedBlocks = new ArrayBuffer[String]
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf(loadDefaults = false)
+    MockitoAnnotations.initMocks(this)
+    when(dependency.partitioner).thenReturn(new HashPartitioner(8))
+    when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+    when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1)))
+    conf.set("spark.shuffle.push.based.enabled", "true")
+    conf.set("spark.shuffle.service.enabled", "true")
+    // Set the env because the shuffler writer gets the shuffle client instance from the env.
+    val mockEnv = mock(classOf[SparkEnv])
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.blockManager).thenReturn(blockManager)
+    SparkEnv.set(mockEnv)
+    when(blockManager.blockStoreClient).thenReturn(shuffleClient)
+  }
+
+  override def afterEach(): Unit = {
+    pushedBlocks.clear()
+    super.afterEach()
+  }
+
+  /** Stubs pushBlocks to record every pushed block id and report each push as successful. */
+  private def interceptPushedBlocksForSuccess(): Unit = {
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        blocks.zip(managedBuffers).foreach { case (blockId, buffer) =>
+          blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+        }
+      })
+  }
+
+  /** Asserts that the prepared requests have exactly the expected sizes, in order. */
+  private def verifyPushRequests(pushRequests: Seq[PushRequest], expectedSizes: Seq[Int]): Unit = {
+    // Comparing whole sequences also fails on a missing or extra request.
+    assert(pushRequests.map(_.size) === expectedSizes)
+  }
+
+  test("A batch of blocks is limited by maxBlockBatchSize") {
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+    conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k")
+    val blockPusher = new TestShuffleBlockPusher(conf)
+    val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    val largeBlockSize = 2 * 1024 * 1024
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+      mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
+      mock(classOf[TransportConf]))
+    assert(pushRequests.length == 3)
+    verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize))
+  }
+
+  test("Large blocks are excluded in the preparation") {
+    conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+    val blockPusher = new TestShuffleBlockPusher(conf)
+    val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+      mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
+    assert(pushRequests.length == 2)
+    verifyPushRequests(pushRequests, Seq(6, 1024))
+  }
+
+  test("Number of blocks in a push request are limited by maxBlocksInFlightPerAddress") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    val blockPusher = new TestShuffleBlockPusher(conf)
+    val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0,
+      mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
+    assert(pushRequests.length == 5)
+    verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
+  }
+
+  test("Basic block push") {
+    interceptPushedBlocksForSuccess()
+    val blockPusher = new TestShuffleBlockPusher(conf)
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    blockPusher.runPendingTasks()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Large blocks are skipped for push") {
+    conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+    interceptPushedBlocksForSuccess()
+    val pusher = new TestShuffleBlockPusher(conf)
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array(2, 2, 2, 2, 2, 2, 2, 1100), dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of blocks in flight per address are limited by maxBlocksInFlightPerAddress") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    interceptPushedBlocksForSuccess()
+    val pusher = new TestShuffleBlockPusher(conf)
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(8))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Hit maxBlocksInFlightPerAddress limit so that the blocks are deferred") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    var blockPendingResponse: String = null
+    var listener: BlockFetchingListener = null
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        // Expecting 2 blocks
+        assert(blocks.length == 2)
+        if (blockPendingResponse == null) {
+          blockPendingResponse = blocks(1)
+          listener = blockFetchListener
+          // Respond with success only for the first block which will cause all the rest of the
+          // blocks to be deferred
+          blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0))
+        } else {
+          blocks.zip(managedBuffers).foreach { case (blockId, buffer) =>
+            blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+          }
+        }
+      })
+    val pusher = new TestShuffleBlockPusher(conf)
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 2)
+    // this will trigger push of deferred blocks
+    listener.onBlockFetchSuccess(blockPendingResponse, mock(classOf[ManagedBuffer]))
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 8)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of shuffle blocks grouped in a single push request is limited by " +
+      "maxBlockBatchSize") {
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+    interceptPushedBlocksForSuccess()
+    val pusher = new TestShuffleBlockPusher(conf)
+    pusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Error retries") {
+    val pusher = new ShuffleBlockPusher(conf)
+    val errorHandler = pusher.createErrorHandler()
+    assert(
+      !errorHandler.shouldRetryError(new RuntimeException(
+        new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+    assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException())))
+    assert(
+      errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+        BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+    assert(errorHandler.shouldRetryError(new Throwable()))
+  }
+
+  test("Error logging") {
+    val pusher = new ShuffleBlockPusher(conf)
+    val errorHandler = pusher.createErrorHandler()
+    assert(
+      !errorHandler.shouldLogError(new RuntimeException(
+        new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+    assert(!errorHandler.shouldLogError(new RuntimeException(
+      new IllegalArgumentException(
+        BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+    assert(errorHandler.shouldLogError(new Throwable()))
+  }
+
+  test("Blocks are continued to push even when a block push fails with collision " +
+      "exception") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    val pusher = new TestShuffleBlockPusher(conf)
+    var failBlock: Boolean = true
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        blocks.foreach(blockId => {
+          if (failBlock) {
+            failBlock = false
+            // Fail the first block with the collision exception.
+            blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException(
+              new IllegalArgumentException(
+                BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))
+          } else {
+            pushedBlocks += blockId
+            blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer]))
+          }
+        })
+      })
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(8))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 7)
+  }
+
+  test("More blocks are not pushed when a block push fails with too late " +
+      "exception") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    val pusher = new TestShuffleBlockPusher(conf)
+    var failBlock: Boolean = true
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        blocks.foreach(blockId => {
+          if (failBlock) {
+            failBlock = false
+            // Fail the first block with the too late exception.
+            blockFetchListener.onBlockFetchFailure(blockId, new RuntimeException(
+              new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)))
+          } else {
+            pushedBlocks += blockId
+            blockFetchListener.onBlockFetchSuccess(blockId, mock(classOf[ManagedBuffer]))
+          }
+        })
+      })
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.isEmpty)
+  }
+
+  test("Connect exceptions remove all the push requests for that host") {
+    when(dependency.getMergerLocs).thenReturn(
+      Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", "client2", 2)))
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        blocks.foreach(blockId => {
+          blockFetchListener.onBlockFetchFailure(
+            blockId, new RuntimeException(new ConnectException()))
+        })
+      })
+    val pusher = new TestShuffleBlockPusher(conf)
+    pusher.initiateBlockPush(
+      mock(classOf[File]), Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0)
+    pusher.runPendingTasks()
+    verify(shuffleClient, times(2))
+      .pushBlocks(any(), any(), any(), any(), any())
+    // 2 blocks for each merger locations
+    assert(pushedBlocks.length == 4)
+    assert(pusher.unreachableBlockMgrs.size == 2)
+  }
+
+  private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) {
+    private[this] val tasks = new LinkedBlockingQueue[Runnable]
+
+    override protected def submitTask(task: Runnable): Unit = {
+      tasks.add(task)
+    }
+
+    def runPendingTasks(): Unit = {
+      // This ensures that all the submitted tasks - updateStateAndCheckIfPushMore and pushUpToMax
+      // are run synchronously.
+      while (!tasks.isEmpty) {
+        tasks.take().run()
+      }
+    }
+
+    override protected def createRequestBuffer(
+        conf: TransportConf,
+        dataFile: File,
+        offset: Long,
+        length: Long): ManagedBuffer = {
+      val managedBuffer = mock(classOf[ManagedBuffer])
+      val byteBuffer = new Array[Byte](length.toInt)
+      when(managedBuffer.nioByteBuffer()).thenReturn(ByteBuffer.wrap(byteBuffer))
+      managedBuffer
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org