Posted to commits@spark.apache.org by mr...@apache.org on 2021/04/26 05:17:51 UTC

[spark] branch master updated: [SPARK-32921][SHUFFLE] MapOutputTracker extensions to support push-based shuffle

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 38ef477  [SPARK-32921][SHUFFLE] MapOutputTracker extensions to support push-based shuffle
38ef477 is described below

commit 38ef4771d447f6135382ee2767b3f32b96cb1b0e
Author: Venkata krishnan Sowrirajan <vs...@linkedin.com>
AuthorDate: Mon Apr 26 00:17:26 2021 -0500

    [SPARK-32921][SHUFFLE] MapOutputTracker extensions to support push-based shuffle
    
    ### What changes were proposed in this pull request?
    This is one of the patches for SPIP SPARK-30602 for push-based shuffle.
    Summary of changes:
    
    - Introduce `MergeStatus` which tracks the partition level metadata for a merged shuffle partition in the Spark driver
    - Unify `MergeStatus` and `MapStatus` under a single trait to allow code reuse inside `MapOutputTracker`
    - Extend `MapOutputTracker` to support registering/unregistering `MergeStatus`, calculating preferred locations for a shuffle while taking merged shuffle partitions into account, and serving reducer requests for block fetching locations that include merged shuffle partitions.
    
    The added APIs in `MapOutputTracker` will be used by `DAGScheduler` in SPARK-32920 and by `ShuffleBlockFetcherIterator` in SPARK-32922.
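
    For illustration, here is a minimal driver-side sketch of how these APIs fit together. It is a sketch only: it assumes push-based shuffle is enabled (so `registerShuffle` sizes the `MergeStatus` array), that `tracker` is a `MapOutputTrackerMaster` already wired to an RpcEnv as in the unit tests, and the shuffle ID, sizes, and merger address below are hypothetical.

    ```scala
    import org.roaringbitmap.RoaringBitmap

    import org.apache.spark.scheduler.MergeStatus
    import org.apache.spark.storage.BlockManagerId

    val shuffleId = 10
    // Register the shuffle with both the number of map tasks and the number of reducers,
    // so the driver can allocate one MergeStatus slot per reduce partition.
    tracker.registerShuffle(shuffleId, numMaps = 2, numReduces = 2)

    // Record a merge result for reduce partition 0: the bitmap tracks which map outputs
    // were merged into the merged shuffle file served from mergerLoc.
    val bitmap = new RoaringBitmap()
    bitmap.add(0)
    bitmap.add(1)
    val mergerLoc = BlockManagerId("shuffle-push-merger", "hostA", 7337)  // hypothetical address
    tracker.registerMergeResult(shuffleId, 0, MergeStatus(mergerLoc, bitmap, 2000L))

    // If the merger becomes unreachable, drop the merge result so reducers fall back to
    // fetching the original unmerged blocks.
    tracker.unregisterMergeResult(shuffleId, 0, mergerLoc)
    ```

    On the reducer side, `getMapSizesByExecutorId` prefers the merged block location when a `MergeStatus` is available for the full map range, and `getMapSizesForMergeResult` provides the fallback to the original map outputs when fetching a merged block (or chunk) fails.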
    
    ### Why are the changes needed?
    Refer to SPARK-30602
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Added unit tests.
    
    Lead-authored-by: Min Shen <mshen@linkedin.com>
    Co-authored-by: Chandni Singh <chsingh@linkedin.com>
    Co-authored-by: Venkata Sowrirajan <vsowrirajan@linkedin.com>
    
    Closes #30480 from Victsm/SPARK-32921.
    
    Lead-authored-by: Venkata krishnan Sowrirajan <vs...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Co-authored-by: Chandni Singh <si...@gmail.com>
    Co-authored-by: Chandni Singh <ch...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../scala/org/apache/spark/MapOutputTracker.scala  | 670 ++++++++++++++++++---
 .../org/apache/spark/scheduler/DAGScheduler.scala  |   3 +-
 .../org/apache/spark/scheduler/MapStatus.scala     |   8 +-
 .../org/apache/spark/scheduler/MergeStatus.scala   | 113 ++++
 .../org/apache/spark/MapOutputTrackerSuite.scala   | 317 +++++++++-
 .../spark/MapStatusesSerDeserBenchmark.scala       |  10 +-
 .../test/scala/org/apache/spark/ShuffleSuite.scala |   6 +-
 .../apache/spark/storage/BlockManagerSuite.scala   |   4 +-
 8 files changed, 1006 insertions(+), 125 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ce71c2c..b749d7e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -29,13 +29,14 @@ import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}
+import org.roaringbitmap.RoaringBitmap
 
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
 import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util._
@@ -49,7 +50,9 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) extends Logging {
+private class ShuffleStatus(
+    numPartitions: Int,
+    numReducers: Int = -1) extends Logging {
 
   private val (readLock, writeLock) = {
     val lock = new ReentrantReadWriteLock()
@@ -87,6 +90,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
   val mapStatuses = new Array[MapStatus](numPartitions)
 
   /**
+   * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
+   * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for
+   * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array
+   * provides a reducer-oriented view of the shuffle status specifically for the results of
+   * merging shuffle partition blocks into per-partition merged shuffle files.
+   */
+  val mergeStatuses = if (numReducers > 0) {
+    new Array[MergeStatus](numReducers)
+  } else {
+    Array.empty[MergeStatus]
+  }
+
+  /**
    * The cached result of serializing the map statuses array. This cache is lazily populated when
    * [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
    */
@@ -103,11 +119,23 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
   private[spark] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
 
   /**
+   * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus.
+   */
+  private[this] var cachedSerializedMergeStatus: Array[Byte] = _
+
+  private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Byte]] = _
+
+  /**
    * Counter tracking the number of partitions that have output. This is a performance optimization
    * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
    * be equivalent to`mapStatuses.count(_ ne null)`.
    */
-  private[this] var _numAvailableOutputs: Int = 0
+  private[this] var _numAvailableMapOutputs: Int = 0
+
+  /**
+   * Counter tracking the number of MergeStatus results received so far from the shuffle services.
+   */
+  private[this] var _numAvailableMergeResults: Int = 0
 
   /**
    * Register a map output. If there is already a registered location for the map output then it
@@ -115,7 +143,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
    */
   def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
     if (mapStatuses(mapIndex) == null) {
-      _numAvailableOutputs += 1
+      _numAvailableMapOutputs += 1
       invalidateSerializedMapOutputStatusCache()
     }
     mapStatuses(mapIndex) = status
@@ -149,13 +177,37 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
   def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
     logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
     if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
-      _numAvailableOutputs -= 1
+      _numAvailableMapOutputs -= 1
       mapStatuses(mapIndex) = null
       invalidateSerializedMapOutputStatusCache()
     }
   }
 
   /**
+   * Register a merge result.
+   */
+  def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock {
+    if (mergeStatuses(reduceId) != status) {
+      _numAvailableMergeResults += 1
+      invalidateSerializedMergeOutputStatusCache()
+    }
+    mergeStatuses(reduceId) = status
+  }
+
+  // TODO support updateMergeResult for similar use cases as updateMapOutput
+
+  /**
+   * Remove the merge result which was served by the specified block manager.
+   */
+  def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
+    if (mergeStatuses(reduceId) != null && mergeStatuses(reduceId).location == bmAddress) {
+      _numAvailableMergeResults -= 1
+      mergeStatuses(reduceId) = null
+      invalidateSerializedMergeOutputStatusCache()
+    }
+  }
+
+  /**
    * Removes all shuffle outputs associated with this host. Note that this will also remove
    * outputs which are served by an external shuffle server (if one exists).
    */
@@ -181,18 +233,33 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
   def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
       if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
-        _numAvailableOutputs -= 1
+        _numAvailableMapOutputs -= 1
         mapStatuses(mapIndex) = null
         invalidateSerializedMapOutputStatusCache()
       }
     }
+    for (reduceId <- mergeStatuses.indices) {
+      if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) {
+        _numAvailableMergeResults -= 1
+        mergeStatuses(reduceId) = null
+        invalidateSerializedMergeOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle map outputs.
+   */
+  def numAvailableMapOutputs: Int = withReadLock {
+    _numAvailableMapOutputs
   }
 
   /**
-   * Number of partitions that have shuffle outputs.
+   * Number of shuffle partitions that have already been merge finalized when push-based
+   * shuffle is enabled.
    */
-  def numAvailableOutputs: Int = withReadLock {
-    _numAvailableOutputs
+  def numAvailableMergeResults: Int = withReadLock {
+    _numAvailableMergeResults
   }
 
   /**
@@ -200,19 +267,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
    */
   def findMissingPartitions(): Seq[Int] = withReadLock {
     val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    assert(missing.size == numPartitions - _numAvailableMapOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}")
     missing
   }
 
   /**
    * Serializes the mapStatuses array into an efficient compressed format. See the comments on
-   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format.
    *
    * This method is designed to be called multiple times and implements caching in order to speed
    * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
-   * serialize the map statuses then serialization will only be performed in a single thread and all
-   * other threads will block until the cache is populated.
+   * serialize the map statuses then serialization will only be performed in a single thread and
+   * all other threads will block until the cache is populated.
    */
   def serializedMapStatus(
       broadcastManager: BroadcastManager,
@@ -220,7 +287,6 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
       minBroadcastSize: Int,
       conf: SparkConf): Array[Byte] = {
     var result: Array[Byte] = null
-
     withReadLock {
       if (cachedSerializedMapStatus != null) {
         result = cachedSerializedMapStatus
@@ -229,7 +295,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
 
     if (result == null) withWriteLock {
       if (cachedSerializedMapStatus == null) {
-        val serResult = MapOutputTracker.serializeMapStatuses(
+        val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus](
           mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
         cachedSerializedMapStatus = serResult._1
         cachedSerializedBroadcast = serResult._2
@@ -241,6 +307,47 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
     result
   }
 
+  /**
+   * Serializes the mapStatuses and mergeStatuses array into an efficient compressed format.
+   * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details
+   * on the serialization format.
+   *
+   * This method is designed to be called multiple times and implements caching in order to speed
+   * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
+   * serialize the statuses array then serialization will only be performed in a single thread and
+   * all other threads will block until the cache is populated.
+   */
+  def serializedMapAndMergeStatus(
+      broadcastManager: BroadcastManager,
+      isLocal: Boolean,
+      minBroadcastSize: Int,
+      conf: SparkConf): (Array[Byte], Array[Byte]) = {
+    val mapStatusesBytes: Array[Byte] =
+      serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf)
+    var mergeStatusesBytes: Array[Byte] = null
+
+    withReadLock {
+      if (cachedSerializedMergeStatus != null) {
+        mergeStatusesBytes = cachedSerializedMergeStatus
+      }
+    }
+
+    if (mergeStatusesBytes == null) withWriteLock {
+      if (cachedSerializedMergeStatus == null) {
+        val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus](
+          mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
+        cachedSerializedMergeStatus = serResult._1
+        cachedSerializedBroadcastMergeStatus = serResult._2
+      }
+
+      // The following line has to be outside the if statement since it's possible that another
+      // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and
+      // `withWriteLock`.
+      mergeStatusesBytes = cachedSerializedMergeStatus
+    }
+    (mapStatusesBytes, mergeStatusesBytes)
+  }
+
   // Used in testing.
   def hasCachedSerializedBroadcast: Boolean = withReadLock {
     cachedSerializedBroadcast != null
@@ -254,6 +361,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
     f(mapStatuses)
   }
 
+  def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock {
+    f(mergeStatuses)
+  }
+
   /**
    * Clears the cached serialized map output statuses.
    */
@@ -269,14 +380,35 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
     }
     cachedSerializedMapStatus = null
   }
+
+  /**
+   * Clears the cached serialized merge result statuses.
+   */
+  def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock {
+    if (cachedSerializedBroadcastMergeStatus != null) {
+      Utils.tryLogNonFatalError {
+        // Use `blocking = false` so that this operation doesn't hang while trying to send cleanup
+        // RPCs to dead executors.
+        cachedSerializedBroadcastMergeStatus.destroy()
+      }
+      cachedSerializedBroadcastMergeStatus = null
+    }
+    cachedSerializedMergeStatus = null
+  }
 }
 
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
+private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int)
+  extends MapOutputTrackerMessage
 private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
 
-private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext)
+private[spark] sealed trait MapOutputTrackerMasterMessage
+private[spark] case class GetMapOutputMessage(shuffleId: Int,
+  context: RpcCallContext) extends MapOutputTrackerMasterMessage
+private[spark] case class GetMapAndMergeOutputMessage(shuffleId: Int,
+  context: RpcCallContext) extends MapOutputTrackerMasterMessage
 
 /** RpcEndpoint class for MapOutputTrackerMaster */
 private[spark] class MapOutputTrackerMasterEndpoint(
@@ -288,8 +420,13 @@ private[spark] class MapOutputTrackerMasterEndpoint(
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case GetMapOutputStatuses(shuffleId: Int) =>
       val hostPort = context.senderAddress.hostPort
-      logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}")
-      tracker.post(new GetMapOutputMessage(shuffleId, context))
+      logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort")
+      tracker.post(GetMapOutputMessage(shuffleId, context))
+
+    case GetMapAndMergeResultStatuses(shuffleId: Int) =>
+      val hostPort = context.senderAddress.hostPort
+      logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort")
+      tracker.post(GetMapAndMergeOutputMessage(shuffleId, context))
 
     case StopMapOutputTracker =>
       logInfo("MapOutputTrackerMasterEndpoint stopped!")
@@ -368,6 +505,40 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
   /**
+   * Called from executors upon fetch failure on an entire merged shuffle reduce partition.
+   * Such failures can happen if the shuffle client fails to fetch the metadata for the given
+   * merged shuffle partition. This method gets the server URIs and output sizes for each
+   * shuffle block that is merged into the specified merged shuffle block, so that a fetch
+   * failure on a merged shuffle block can fall back to fetching the unmerged blocks.
+   *
+   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+   *         and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
+   */
+  def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
+  /**
+   * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This
+   * gets the server URIs and output sizes for each shuffle block that is merged into the
+   * specified merged shuffle partition chunk, so that a fetch failure on a merged shuffle block
+   * chunk can fall back to fetching the unmerged blocks.
+   *
+   * chunkBitmap tracks the mapIds which are part of the current merged chunk. This way, if there
+   * is a fetch failure on the merged chunk, it can fall back to fetching the corresponding
+   * original blocks that are part of this merged chunk.
+   *
+   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+   *         and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
+   */
+  def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int,
+      chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
+  /**
    * Deletes map output status information for the specified shuffle stage.
    */
   def unregisterShuffle(shuffleId: Int): Unit
@@ -415,8 +586,11 @@ private[spark] class MapOutputTrackerMaster(
 
   private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
-  // requests for map output statuses
-  private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
+  // requests for MapOutputTrackerMasterMessages
+  private val mapOutputTrackerMasterMessages =
+    new LinkedBlockingQueue[MapOutputTrackerMasterMessage]
+
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
 
   // Thread pool used for handling map output status requests. This is a separate thread pool
   // to ensure we don't block the normal dispatcher threads.
@@ -439,31 +613,47 @@ private[spark] class MapOutputTrackerMaster(
     throw new IllegalArgumentException(msg)
   }
 
-  def post(message: GetMapOutputMessage): Unit = {
-    mapOutputRequests.offer(message)
+  def post(message: MapOutputTrackerMasterMessage): Unit = {
+    mapOutputTrackerMasterMessages.offer(message)
   }
 
   /** Message loop used for dispatching messages. */
   private class MessageLoop extends Runnable {
+    private def handleStatusMessage(
+        shuffleId: Int,
+        context: RpcCallContext,
+        needMergeOutput: Boolean): Unit = {
+      val hostPort = context.senderAddress.hostPort
+      val shuffleStatus = shuffleStatuses.get(shuffleId).head
+      logDebug(s"Handling request to send ${if (needMergeOutput) "map" else "map/merge"}" +
+        s" output locations for shuffle $shuffleId to $hostPort")
+      if (needMergeOutput) {
+        context.reply(
+          shuffleStatus.
+            serializedMapAndMergeStatus(broadcastManager, isLocal, minSizeForBroadcast, conf))
+      } else {
+        context.reply(
+          shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf))
+      }
+    }
+
     override def run(): Unit = {
       try {
         while (true) {
           try {
-            val data = mapOutputRequests.take()
-             if (data == PoisonPill) {
+            val data = mapOutputTrackerMasterMessages.take()
+            if (data == PoisonPill) {
               // Put PoisonPill back so that other MessageLoops can see it.
-              mapOutputRequests.offer(PoisonPill)
+              mapOutputTrackerMasterMessages.offer(PoisonPill)
               return
             }
-            val context = data.context
-            val shuffleId = data.shuffleId
-            val hostPort = context.senderAddress.hostPort
-            logDebug("Handling request to send map output locations for shuffle " + shuffleId +
-              " to " + hostPort)
-            val shuffleStatus = shuffleStatuses.get(shuffleId).head
-            context.reply(
-              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast,
-                conf))
+
+            data match {
+              case GetMapOutputMessage(shuffleId, context) =>
+                handleStatusMessage(shuffleId, context, false)
+              case GetMapAndMergeOutputMessage(shuffleId, context) =>
+                handleStatusMessage(shuffleId, context, true)
+            }
           } catch {
             case NonFatal(e) => logError(e.getMessage, e)
           }
@@ -475,16 +665,22 @@ private[spark] class MapOutputTrackerMaster(
   }
 
   /** A poison endpoint that indicates MessageLoop should exit its message loop. */
-  private val PoisonPill = new GetMapOutputMessage(-99, null)
+  private val PoisonPill = GetMapOutputMessage(-99, null)
 
   // Used only in unit tests.
   private[spark] def getNumCachedSerializedBroadcast: Int = {
     shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
   }
 
-  def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
-    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
-      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+  def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = {
+    if (pushBasedShuffleEnabled) {
+      if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) {
+        throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+      }
+    } else {
+      if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
+        throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+      }
     }
   }
 
@@ -524,10 +720,49 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = {
+    shuffleStatuses(shuffleId).addMergeResult(reduceId, status)
+  }
+
+  def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit = {
+    statuses.foreach {
+      case (reduceId, status) => registerMergeResult(shuffleId, reduceId, status)
+    }
+  }
+
+  /**
+   * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId
+   * is specified, it will only unregister the merge result if the mapId is part of that merge
+   * result.
+   *
+   * @param shuffleId the shuffleId.
+   * @param reduceId  the reduceId.
+   * @param bmAddress block manager address.
+   * @param mapId     the optional mapId which should be checked to see if it was part of the
+   *                  merge result.
+   */
+  def unregisterMergeResult(
+      shuffleId: Int,
+      reduceId: Int,
+      bmAddress: BlockManagerId,
+      mapId: Option[Int] = None): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
+        if (mergeStatus != null && (mapId.isEmpty || mergeStatus.tracker.contains(mapId.get))) {
+          shuffleStatus.removeMergeResult(reduceId, bmAddress)
+          incrementEpoch()
+        }
+      case None =>
+        throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID")
+    }
+  }
+
   /** Unregister shuffle data */
   def unregisterShuffle(shuffleId: Int): Unit = {
     shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
       shuffleStatus.invalidateSerializedMapOutputStatusCache()
+      shuffleStatus.invalidateSerializedMergeOutputStatusCache()
     }
   }
 
@@ -554,7 +789,12 @@ private[spark] class MapOutputTrackerMaster(
   def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
 
   def getNumAvailableOutputs(shuffleId: Int): Int = {
-    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+    shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0)
+  }
+
+  /** VisibleForTesting. Invoked in tests only. */
+  private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = {
+    shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0)
   }
 
   /**
@@ -633,7 +873,9 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the map outputs for that
+   * partition are pre-merged and the merge ratio is above the threshold, return the node where
+   * the merged block is located.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
@@ -641,15 +883,40 @@ private[spark] class MapOutputTrackerMaster(
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
+            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+            Seq(status.location.host)
+          } else {
+            Nil
+          }
+        }
       } else {
         Nil
       }
+      if (preferredLoc.nonEmpty) {
+        preferredLoc
+      } else {
+        if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
+          dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
+          val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
+            dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
+          if (blockManagerIds.nonEmpty) {
+            blockManagerIds.get.map(_.host)
+          } else {
+            Nil
+          }
+        } else {
+          Nil
+        }
+      }
     } else {
       Nil
     }
@@ -774,8 +1041,25 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  // This method is only called in local mode. Since push-based shuffle won't be
+  // enabled in local mode, this method returns an empty iterator.
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    Seq.empty.toIterator
+  }
+
+  // This method is only called in local mode. Since push-based shuffle won't be
+  // enabled in local mode, this method returns an empty iterator.
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int,
+      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    Seq.empty.toIterator
+  }
+
   override def stop(): Unit = {
-    mapOutputRequests.offer(PoisonPill)
+    mapOutputTrackerMasterMessages.offer(PoisonPill)
     threadpool.shutdown()
     try {
       sendTracker(StopMapOutputTracker)
@@ -799,6 +1083,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
   val mapStatuses: Map[Int, Array[MapStatus]] =
     new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
 
+  val mergeStatuses: Map[Int, Array[MergeStatus]] =
+    new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala
+
+  private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf)
+
   /**
    * A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching
    * the same shuffle block.
@@ -812,61 +1101,150 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot identify the list of
+      // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // Use the MergeStatus's partition level bitmap since we are doing partition level fallback
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
+        mapOutputStatuses, mergeStatus.tracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int,
+      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf)
+    try {
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
+        chunkTracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
         throw e
     }
   }
 
   /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+   * Get or fetch the array of MapStatuses and MergeStatuses (if push-based shuffle is enabled)
+   * for a given shuffle ID. NOTE: clients MUST synchronize
    * on this array when reading it, because on the driver, we may be changing it in place.
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
-      val startTimeNs = System.nanoTime()
-      fetchingLock.withLock(shuffleId) {
-        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          try {
-            fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
-          } catch {
-            case e: SparkException =>
-              throw new MetadataFetchFailedException(shuffleId, -1,
-                s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
-                  e.getCause)
+  private def getStatuses(
+      shuffleId: Int,
+      conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
+    if (fetchMergeResult) {
+      val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+      val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull
+
+      if (mapOutputStatuses == null || mergeOutputStatuses == null) {
+        logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
+        val startTimeNs = System.nanoTime()
+        fetchingLock.withLock(shuffleId) {
+          var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+          var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
+          if (fetchedMapStatuses == null || fetchedMergeStatuses == null) {
+            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+            val fetchedBytes =
+              askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId))
+            try {
+              fetchedMapStatuses =
+                MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf)
+              fetchedMergeStatuses =
+                MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, conf)
+            } catch {
+              case e: SparkException =>
+                throw new MetadataFetchFailedException(shuffleId, -1,
+                  s"Unable to deserialize broadcasted map/merge statuses" +
+                    s" for shuffle $shuffleId: " + e.getCause)
+            }
+            logInfo("Got the map/merge output locations")
+            mapStatuses.put(shuffleId, fetchedMapStatuses)
+            mergeStatuses.put(shuffleId, fetchedMergeStatuses)
           }
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
+          logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+          (fetchedMapStatuses, fetchedMergeStatuses)
         }
-        logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
-          s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
-        fetchedStatuses
+      } else {
+        (mapOutputStatuses, mergeOutputStatuses)
       }
     } else {
-      statuses
+      val statuses = mapStatuses.get(shuffleId).orNull
+      if (statuses == null) {
+        logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+        val startTimeNs = System.nanoTime()
+        fetchingLock.withLock(shuffleId) {
+          var fetchedStatuses = mapStatuses.get(shuffleId).orNull
+          if (fetchedStatuses == null) {
+            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+            val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
+            try {
+              fetchedStatuses =
+                MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf)
+            } catch {
+              case e: SparkException =>
+                throw new MetadataFetchFailedException(shuffleId, -1,
+                  s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
+                    e.getCause)
+            }
+            logInfo("Got the map output locations")
+            mapStatuses.put(shuffleId, fetchedStatuses)
+          }
+          logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+          (fetchedStatuses, null)
+        }
+      } else {
+        (statuses, null)
+      }
     }
   }
 
-
   /** Unregister shuffle data. */
   def unregisterShuffle(shuffleId: Int): Unit = {
     mapStatuses.remove(shuffleId)
+    mergeStatuses.remove(shuffleId)
   }
 
   /**
@@ -880,6 +1258,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
         logInfo("Updating epoch to " + newEpoch + " and clearing cache")
         epoch = newEpoch
         mapStatuses.clear()
+        mergeStatuses.clear()
       }
     }
   }
@@ -891,11 +1270,13 @@ private[spark] object MapOutputTracker extends Logging {
   private val DIRECT = 0
   private val BROADCAST = 1
 
-  // Serialize an array of map output locations into an efficient byte format so that we can send
-  // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will
-  // generally be pretty compressible because many map outputs will be on the same hostname.
-  def serializeMapStatuses(
-      statuses: Array[MapStatus],
+  private val SHUFFLE_PUSH_MAP_ID = -1
+
+  // Serialize an array of map/merge output locations into an efficient byte format so that we can
+  // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will
+  // generally be pretty compressible because many outputs will be on the same hostname.
+  def serializeOutputStatuses[T <: ShuffleOutputStatus](
+      statuses: Array[T],
       broadcastManager: BroadcastManager,
       isLocal: Boolean,
       minBroadcastSize: Int,
@@ -931,15 +1312,16 @@ private[spark] object MapOutputTracker extends Logging {
         oos.close()
       }
       val outArr = out.toByteArray
-      logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
+      logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arr.length)
       (outArr, bcast)
     } else {
       (arr, null)
     }
   }
 
-  // Opposite of serializeMapStatuses.
-  def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = {
+  // Opposite of serializeOutputStatuses.
+  def deserializeOutputStatuses[T <: ShuffleOutputStatus](
+      bytes: Array[Byte], conf: SparkConf): Array[T] = {
     assert (bytes.length > 0)
 
     def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
@@ -958,20 +1340,22 @@ private[spark] object MapOutputTracker extends Logging {
 
     bytes(0) match {
       case DIRECT =>
-        deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]]
+        deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[T]]
       case BROADCAST =>
         try {
           // deserialize the Broadcast, pull .value array out of it, and then deserialize that
           val bcast = deserializeObject(bytes, 1, bytes.length - 1).
             asInstanceOf[Broadcast[Array[Byte]]]
-          logInfo("Broadcast mapstatuses size = " + bytes.length +
+          logInfo("Broadcast outputstatuses size = " + bytes.length +
             ", actual size = " + bcast.value.length)
           // Important - ignore the DIRECT tag ! Start from offset 1
-          deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]]
+          deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[T]]
         } catch {
           case e: IOException =>
-            logWarning("Exception encountered during deserializing broadcasted map statuses: ", e)
-            throw new SparkException("Unable to deserialize broadcasted map statuses", e)
+            logWarning("Exception encountered during deserializing broadcasted" +
+              " output statuses: ", e)
+            throw new SparkException("Unable to deserialize broadcasted" +
+              " output statuses", e)
         }
       case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
     }
@@ -983,15 +1367,19 @@ private[spark] object MapOutputTracker extends Logging {
    * stored at that block manager.
    * Note that empty blocks are filtered in the result.
    *
+   * If push-based shuffle is enabled and an array of merge statuses is available, prioritize
+   * the locations of the merged shuffle partitions over unmerged shuffle blocks.
+   *
    * If any of the statuses is null (indicating a missing location due to a failed mapper),
    * throws a FetchFailedException.
    *
    * @param shuffleId Identifier for the shuffle
    * @param startPartition Start of map output partition ID range (included in range)
    * @param endPartition End of map output partition ID range (excluded from range)
-   * @param statuses List of map statuses, indexed by map partition index.
+   * @param mapStatuses List of map statuses, indexed by map partition index.
    * @param startMapIndex Start Map index.
    * @param endMapIndex End Map index.
+   * @param mergeStatuses List of merge statuses, indexed by reduce ID.
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
    *         tuples describing the shuffle blocks that are stored at that block manager.
@@ -1000,18 +1388,57 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0
+      && endMapIndex == mapStatuses.length) {
+      // We have MergeStatus and the full range of mapIds is requested, so return merged blocks.
+      val numMaps = mapStatuses.length
+      mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
+        case (mergeStatus, partId) =>
+          val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
+            // If MergeStatus is available for the given partition, add location of the
+            // pre-merged shuffle partition for this partition ID. Here we create a
+            // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
+            // a merged shuffle block.
+            splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
+              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1))
+            // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
+            // shuffle partition blocks, fetch the original map produced shuffle partition blocks
+            val mapStatusesWithIndex = mapStatuses.zipWithIndex
+            mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex)
+          } else {
+            // If MergeStatus is not available for the given partition, fall back to
+            // fetching all the original mapper shuffle partition blocks
+            mapStatuses.zipWithIndex.toSeq
+          }
+          // Add location for the mapper shuffle partition blocks
+          for ((mapStatus, mapIndex) <- remainingMapStatuses) {
+            validateStatus(mapStatus, shuffleId, partId)
+            val size = mapStatus.getSizeForBlock(partId)
+            if (size != 0) {
+              splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) +=
+                ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex))
+            }
+          }
+      }
+    } else {
+      val iter = mapStatuses.iterator.zipWithIndex
+      for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
+        validateStatus(status, shuffleId, startPartition)
         for (part <- startPartition until endPartition) {
           val size = status.getSizeForBlock(part)
           if (size != 0) {
@@ -1024,4 +1451,47 @@ private[spark] object MapOutputTracker extends Logging {
 
     splitsByAddress.mapValues(_.toSeq).iterator
   }
+
+  /**
+   * Given a shuffle ID, a partition ID, an array of map statuses, and a bitmap corresponding
+   * to either a merged shuffle partition or a merged shuffle partition chunk, identify
+   * the metadata about the shuffle partition blocks that are merged into the merged shuffle
+   * partition or partition chunk represented by the bitmap.
+   *
+   * @param shuffleId Identifier for the shuffle
+   * @param partitionId The partition ID of the MergeStatus for which we look for the metadata
+   *                    of the merged shuffle partition blocks
+   * @param mapStatuses List of map statuses, indexed by map ID
+   * @param tracker     bitmap containing mapIndexes that belong to the merged block or merged
+   *                    block chunk.
+   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+   *         and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
+   */
+  def getMapStatusesForMergeStatus(
+      shuffleId: Int,
+      partitionId: Int,
+      mapStatuses: Array[MapStatus],
+      tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null && tracker != null)
+    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+    for ((status, mapIndex) <- mapStatuses.zipWithIndex) {
+      // Only add blocks that are merged
+      if (tracker.contains(mapIndex)) {
+        MapOutputTracker.validateStatus(status, shuffleId, partitionId)
+        splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+          ((ShuffleBlockId(shuffleId, status.mapId, partitionId),
+            status.getSizeForBlock(partitionId), mapIndex))
+      }
+    }
+    splitsByAddress.mapValues(_.toSeq).iterator
+  }
+
+  def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int): Unit = {
+    if (status == null) {
+      val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partition"
+      logError(errorMessage)
+      throw new MetadataFetchFailedException(shuffleId, partition, errorMessage)
+    }
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c2e7c4d..a92d9fa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -456,7 +456,8 @@ private[spark] class DAGScheduler(
       // since we can't do it in the RDD constructor because # of partitions is unknown
       logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
         s"shuffle ${shuffleDep.shuffleId}")
-      mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
+      mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length,
+        shuffleDep.partitioner.numPartitions)
     }
     stage
   }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 1239c32..07eed76 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -29,11 +29,17 @@ import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.Utils
 
 /**
+ * A common trait between [[MapStatus]] and [[MergeStatus]]. This allows us to reuse existing
+ * code to handle MergeStatus inside MapOutputTracker.
+ */
+private[spark] trait ShuffleOutputStatus
+
+/**
  * Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
  * task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing
  * on to the reduce tasks.
  */
-private[spark] sealed trait MapStatus {
+private[spark] sealed trait MapStatus extends ShuffleOutputStatus {
   /** Location where this task output is. */
   def location: BlockManagerId
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
new file mode 100644
index 0000000..77d8f8e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.network.shuffle.protocol.MergeStatuses
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+
+/**
+ * The status for the result of merging shuffle partition blocks per individual shuffle partition
+ * maintained by the scheduler. The scheduler would separate the
+ * [[org.apache.spark.network.shuffle.protocol.MergeStatuses]] received from
+ * ExternalShuffleService into individual [[MergeStatus]] which is maintained inside
+ * MapOutputTracker to be served to the reducers when they start fetching shuffle partition
+ * blocks. Note that the reducers are ultimately fetching individual chunks inside a merged
+ * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]].
+ * Between the scheduler maintained MergeStatus and the shuffle service maintained per shuffle
+ * partition meta file, we are effectively dividing the metadata for a push-based shuffle into
+ * 2 layers. The scheduler would track the top-level metadata at the shuffle partition level
+ * with MergeStatus, and the shuffle service would maintain the partition level metadata about
+ * how to further divide a merged shuffle partition into multiple chunks with the per-partition
+ * meta file. This helps to reduce the amount of data the scheduler needs to maintain for
+ * push-based shuffle.
+ */
+private[spark] class MergeStatus(
+    private[this] var loc: BlockManagerId,
+    private[this] var mapTracker: RoaringBitmap,
+    private[this] var size: Long)
+  extends Externalizable with ShuffleOutputStatus {
+
+  protected def this() = this(null, null, -1) // For deserialization only
+
+  def location: BlockManagerId = loc
+
+  def totalSize: Long = size
+
+  def tracker: RoaringBitmap = mapTracker
+
+  /**
+   * Get the list of mapper IDs for missing mapper partition blocks that are not merged.
+   * The reducer will use this information to decide which shuffle partition blocks to
+   * fetch in the original way.
+   */
+  def getMissingMaps(numMaps: Int): Seq[Int] = {
+    (0 until numMaps).filter(i => !mapTracker.contains(i))
+  }
+
+  /**
+   * Get the number of missing map outputs, i.e. mapper partition blocks that are not merged.
+   */
+  def getNumMissingMapOutputs(numMaps: Int): Int = {
+    (0 until numMaps).count(i => !mapTracker.contains(i))
+  }
+
+  override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+    loc.writeExternal(out)
+    mapTracker.writeExternal(out)
+    out.writeLong(size)
+  }
+
+  override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+    loc = BlockManagerId(in)
+    mapTracker = new RoaringBitmap()
+    mapTracker.readExternal(in)
+    size = in.readLong()
+  }
+}
+
+private[spark] object MergeStatus {
+  // Dummy number of reduces for the tests where push-based shuffle is not enabled
+  val SHUFFLE_PUSH_DUMMY_NUM_REDUCES = 1
+
+  /**
+   * Separate a MergeStatuses received from an ExternalShuffleService into individual
+   * MergeStatus. The scheduler is responsible for providing the location information
+   * for the given ExternalShuffleService.
+   */
+  def convertMergeStatusesToMergeStatusArr(
+      mergeStatuses: MergeStatuses,
+      loc: BlockManagerId): Seq[(Int, MergeStatus)] = {
+    assert(mergeStatuses.bitmaps.length == mergeStatuses.reduceIds.length &&
+      mergeStatuses.bitmaps.length == mergeStatuses.sizes.length)
+    val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port)
+    mergeStatuses.bitmaps.zipWithIndex.map {
+      case (bitmap, index) =>
+        val mergeStatus = new MergeStatus(mergerLoc, bitmap, mergeStatuses.sizes(index))
+        (mergeStatuses.reduceIds(index), mergeStatus)
+    }
+  }
+
+  def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = {
+    new MergeStatus(loc, bitmap, size)
+  }
+}
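
For readers unfamiliar with the new class, the following is a minimal usage sketch (not part of
the patch) of the MergeStatus API defined above. The merger location, bitmap contents and sizes
are illustrative only, and since MergeStatus is private[spark] this would only compile inside the
org.apache.spark package (e.g. in a test):

    import org.roaringbitmap.RoaringBitmap

    import org.apache.spark.scheduler.MergeStatus
    import org.apache.spark.storage.BlockManagerId

    // The merger location uses the reserved shuffle-merger identifier, mirroring
    // convertMergeStatusesToMergeStatusArr above. Host and port are made up.
    val mergerLoc =
      BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "hostA", 7337)

    // Bitmap of mapper IDs whose blocks were merged into this shuffle partition.
    val merged = new RoaringBitmap()
    merged.add(0)
    merged.add(2)
    merged.add(3)

    val status = MergeStatus(mergerLoc, merged, 3000L)

    // With 4 map tasks in total, mapper 1 was not merged, so a reducer would fall
    // back to fetching its block in the original, unmerged way.
    assert(status.getMissingMaps(4) == Seq(1))
    assert(status.getNumMissingMapOutputs(4) == 1)
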
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 83fe450..f4b47e2 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -21,17 +21,19 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
+import org.roaringbitmap.RoaringBitmap
 
 import org.apache.spark.LocalSparkContext._
 import org.apache.spark.broadcast.BroadcastManager
 import org.apache.spark.internal.config._
 import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE}
+import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
-import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
+import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
 
-class MapOutputTrackerSuite extends SparkFunSuite {
+class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
   private val conf = new SparkConf
 
   private def newTrackerMaster(sparkConf: SparkConf = conf) = {
@@ -58,7 +60,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val tracker = newTrackerMaster()
     tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
-    tracker.registerShuffle(10, 2)
+    tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     assert(tracker.containsShuffle(10))
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
@@ -82,7 +84,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val tracker = newTrackerMaster()
     tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
-    tracker.registerShuffle(10, 2)
+    tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     val compressedSize1000 = MapStatus.compressSize(1000L)
     val compressedSize10000 = MapStatus.compressSize(10000L)
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
@@ -105,7 +107,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val tracker = newTrackerMaster()
     tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
-    tracker.registerShuffle(10, 2)
+    tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     val compressedSize1000 = MapStatus.compressSize(1000L)
     val compressedSize10000 = MapStatus.compressSize(10000L)
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
@@ -140,7 +142,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     mapWorkerTracker.trackerEndpoint =
       mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
 
-    masterTracker.registerShuffle(10, 1)
+    masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     mapWorkerTracker.updateEpoch(masterTracker.getEpoch)
     // This is expected to fail because no outputs have been registered for the shuffle.
     intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) }
@@ -183,7 +185,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
 
     // Message size should be ~123B, and no exception should be thrown
-    masterTracker.registerShuffle(10, 1)
+    masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5))
     val senderAddress = RpcAddress("localhost", 12345)
@@ -217,7 +219,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     // on hostA with output size 2
     // on hostA with output size 2
     // on hostB with output size 3
-    tracker.registerShuffle(10, 3)
+    tracker.registerShuffle(10, 3, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
         Array(2L), 5))
     tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
@@ -260,7 +262,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       // Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception.
       // Note that the size is hand-selected here because map output statuses are compressed before
       // being sent.
-      masterTracker.registerShuffle(20, 100)
+      masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
       (0 until 100).foreach { i =>
         masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
           BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
@@ -306,7 +308,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val tracker = newTrackerMaster()
     tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
       new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
-    tracker.registerShuffle(10, 2)
+    tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
 
     val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L))
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
@@ -332,6 +334,219 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     rpcEnv.shutdown()
   }
 
+  test("SPARK-32921: master register and unregister merge result") {
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, true)
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    tracker.registerShuffle(10, 4, 2)
+    assert(tracker.containsShuffle(10))
+    val bitmap = new RoaringBitmap()
+    bitmap.add(0)
+    bitmap.add(1)
+
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap, 1000L))
+    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
+      bitmap, 1000L))
+    assert(tracker.getNumAvailableMergeResults(10) == 2)
+    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
+    assert(tracker.getNumAvailableMergeResults(10) == 1)
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: get map sizes with merged shuffle") {
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, true)
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val bitmap = new RoaringBitmap()
+    bitmap.add(0)
+    bitmap.add(1)
+    bitmap.add(3)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+      bitmap, 3000L))
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+    assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1),
+        (ShuffleBlockId(10, 2, 0), size1000, 2)))))
+
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: get map statuses from merged shuffle") {
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, true)
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
+    intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+    val bitmap = new RoaringBitmap()
+    bitmap.add(0)
+    bitmap.add(1)
+    bitmap.add(2)
+    bitmap.add(3)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+      bitmap, 4000L))
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+    assert(slaveTracker.getMapSizesForMergeResult(10, 0).toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+        (ShuffleBlockId(10, 1, 0), size1000, 1), (ShuffleBlockId(10, 2, 0), size1000, 2),
+        (ShuffleBlockId(10, 3, 0), size1000, 3)))))
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: get map statuses for merged shuffle block chunks") {
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, true)
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+
+    val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+    val slaveTracker = new MapOutputTrackerWorker(conf)
+    slaveTracker.trackerEndpoint =
+      slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 4, 1)
+    slaveTracker.updateEpoch(masterTracker.getEpoch)
+
+    val blockMgrId = BlockManagerId("a", "hostA", 1000)
+    masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+    masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+    masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+    masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+
+    val chunkBitmap = new RoaringBitmap()
+    chunkBitmap.add(0)
+    chunkBitmap.add(2)
+    val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+    assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap).toSeq ===
+      Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+        (ShuffleBlockId(10, 2, 0), size1000, 2))))
+    )
+    masterTracker.stop()
+    slaveTracker.stop()
+    rpcEnv.shutdown()
+    slaveRpcEnv.shutdown()
+  }
+
+  test("SPARK-32921: getPreferredLocationsForShuffle with MergeStatus") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    sc = new SparkContext("local", "test", conf.clone())
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    // Setup 6 map tasks
+    // on hostA with output size 2
+    // on hostA with output size 2
+    // on hostB with output size 3
+    // on hostB with output size 3
+    // on hostC with output size 1
+    // on hostC with output size 1
+    tracker.registerShuffle(10, 6, 1)
+    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(2L), 5))
+    tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(2L), 6))
+    tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(3L), 7))
+    tracker.registerMapOutput(10, 3, MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(3L), 8))
+    tracker.registerMapOutput(10, 4, MapStatus(BlockManagerId("c", "hostC", 1000),
+      Array(1L), 9))
+    tracker.registerMapOutput(10, 5, MapStatus(BlockManagerId("c", "hostC", 1000),
+      Array(1L), 10))
+
+    val rdd = sc.parallelize(1 to 6, 6).map(num => (num, num).asInstanceOf[Product2[Int, Int]])
+    val mockShuffleDep = mock(classOf[ShuffleDependency[Int, Int, _]])
+    when(mockShuffleDep.shuffleId).thenReturn(10)
+    when(mockShuffleDep.partitioner).thenReturn(new HashPartitioner(1))
+    when(mockShuffleDep.rdd).thenReturn(rdd)
+
+    // Prepare a MergeStatus that merges 5 out of the 6 blocks
+    val bitmap80 = new RoaringBitmap()
+    bitmap80.add(0)
+    bitmap80.add(1)
+    bitmap80.add(2)
+    bitmap80.add(3)
+    bitmap80.add(4)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap80, 11))
+
+    val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+    assert(preferredLocs1.nonEmpty)
+    assert(preferredLocs1.length === 1)
+    assert(preferredLocs1.head === "hostA")
+
+    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
+    // Prepare another MergeStatus that merges only 1 out of the 6 blocks
+    val bitmap20 = new RoaringBitmap()
+    bitmap20.add(0)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap20, 2))
+
+    val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+    assert(preferredLocs2.nonEmpty)
+    assert(preferredLocs2.length === 2)
+    assert(preferredLocs2 === Seq("hostA", "hostB"))
+
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
+
   test("SPARK-34939: remote fetch using broadcast if broadcasted value is destroyed") {
     val newConf = new SparkConf
     newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
@@ -346,7 +561,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       rpcEnv.stop(masterTracker.trackerEndpoint)
       rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
 
-      masterTracker.registerShuffle(20, 100)
+      masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
       (0 until 100).foreach { i =>
         masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
           BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
@@ -368,9 +583,85 @@ class MapOutputTrackerSuite extends SparkFunSuite {
         shuffleStatus.cachedSerializedBroadcast.destroy(true)
       }
       val err = intercept[SparkException] {
-        MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
+        MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf)
+      }
+      assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses"))
+    }
+  }
+
+  test("SPARK-32921: test new protocol changes fetching both Map and Merge status in single RPC") {
+    val newConf = new SparkConf
+    newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
+    newConf.set(RPC_ASK_TIMEOUT, "1") // Fail fast
+    newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize
+    newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    newConf.set(IS_TESTING, true)
+
+    // Needs TorrentBroadcast, so we need a SparkContext
+    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
+      val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+      val rpcEnv = sc.env.rpcEnv
+      val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+      rpcEnv.stop(masterTracker.trackerEndpoint)
+      rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+      val bitmap1 = new RoaringBitmap()
+      bitmap1.add(1)
+
+      masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
+      (0 until 100).foreach { i =>
+        masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+          BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
       }
-      assert(err.getMessage.contains("Unable to deserialize broadcasted map statuses"))
+      masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000),
+        bitmap1, 1000L))
+
+      val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf))
+      val mapWorkerTracker = new MapOutputTrackerWorker(conf)
+      mapWorkerTracker.trackerEndpoint =
+        mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+      val fetchedBytes = mapWorkerTracker.trackerEndpoint
+        .askSync[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(20))
+      assert(masterTracker.getNumAvailableMergeResults(20) == 1)
+      assert(masterTracker.getNumAvailableOutputs(20) == 100)
+
+      val mapOutput =
+        MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, newConf)
+      val mergeOutput =
+        MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, newConf)
+      assert(mapOutput.length == 100)
+      assert(mergeOutput.length == 1)
+      mapWorkerTracker.stop()
+      masterTracker.stop()
     }
   }
+
+  test("SPARK-32921: unregister merge result if it is present and contains the map Id") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    tracker.registerShuffle(10, 4, 2)
+    assert(tracker.containsShuffle(10))
+    val bitmap1 = new RoaringBitmap()
+    bitmap1.add(0)
+    bitmap1.add(1)
+    tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+      bitmap1, 1000L))
+
+    val bitmap2 = new RoaringBitmap()
+    bitmap2.add(5)
+    bitmap2.add(6)
+    tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
+      bitmap2, 1000L))
+    assert(tracker.getNumAvailableMergeResults(10) == 2)
+    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000), Option(0))
+    assert(tracker.getNumAvailableMergeResults(10) == 1)
+    tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(1))
+    assert(tracker.getNumAvailableMergeResults(10) == 1)
+    tracker.unregisterMergeResult(10, 1, BlockManagerId("b", "hostB", 1000), Option(5))
+    assert(tracker.getNumAvailableMergeResults(10) == 0)
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
 }
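
Taken together, the new tests above sketch the driver-side flow for push-based shuffle. The
condensed example below is distilled from those tests (shuffle ids, hosts and sizes are
illustrative, and newTrackerMaster() is the suite's helper for creating a MapOutputTrackerMaster):

    // registerShuffle now also takes the number of reduce partitions so the tracker
    // can keep merge statuses per reduce partition.
    val tracker = newTrackerMaster()
    tracker.registerShuffle(10, 4, 2)

    // Map outputs are registered exactly as before.
    tracker.registerMapOutput(10, 0,
      MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 0))

    // Once the external shuffle service reports a merged reduce partition, the
    // driver registers the corresponding MergeStatus for it.
    val bitmap = new RoaringBitmap()
    bitmap.add(0)
    tracker.registerMergeResult(10, 0,
      MergeStatus(BlockManagerId("a", "hostA", 1000), bitmap, 1000L))
    assert(tracker.getNumAvailableMergeResults(10) == 1)

    // A merged partition lost to a fetch failure is unregistered again.
    tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))

On the reducer side, getMapSizesByExecutorId then returns the merged block (with a mapId of -1)
plus any blocks that were not merged, as exercised in the "get map sizes with merged shuffle"
test above.
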
diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
index e433f42..d808823 100644
--- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark
 
 import org.apache.spark.benchmark.Benchmark
 import org.apache.spark.benchmark.BenchmarkBase
-import org.apache.spark.scheduler.CompressedMapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, MergeStatus}
 import org.apache.spark.storage.BlockManagerId
 
 /**
@@ -50,7 +50,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
 
     val shuffleId = 10
 
-    tracker.registerShuffle(shuffleId, numMaps)
+    tracker.registerShuffle(shuffleId, numMaps, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     val r = new scala.util.Random(912)
     (0 until numMaps).foreach { i =>
       tracker.registerMapOutput(shuffleId, i,
@@ -66,7 +66,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
     var serializedMapStatusSizes = 0
     var serializedBroadcastSizes = 0
 
-    val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses(
+    val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeOutputStatuses(
       shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize,
       sc.getConf)
     serializedMapStatusSizes = serializedMapStatus.length
@@ -75,12 +75,12 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
     }
 
     benchmark.addCase("Serialization") { _ =>
-      MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
+      MapOutputTracker.serializeOutputStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
         tracker.isLocal, minBroadcastSize, sc.getConf)
     }
 
     benchmark.addCase("Deserialization") { _ =>
-      val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf)
+      val result = MapOutputTracker.deserializeOutputStatuses(serializedMapStatus, sc.getConf)
       assert(result.length == numMaps)
     }
 
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 56684d9..126faec 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark
 
 import java.io.File
 import java.util.{Locale, Properties}
-import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
+import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService }
 
 import scala.collection.JavaConverters._
 
@@ -33,7 +33,7 @@ import org.apache.spark.internal.config
 import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
 import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MapStatus, MergeStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 import org.apache.spark.shuffle.ShuffleWriter
 import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
@@ -367,7 +367,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, shuffleDep)
-    mapTrackerMaster.registerShuffle(0, 1)
+    mapTrackerMaster.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
 
     // first attempt -- its successful
     val context1 =
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 055ee0d..707e168 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -53,7 +53,7 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo
 import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient}
 import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
-import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, SparkListenerBlockUpdated}
+import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, MergeStatus, SparkListenerBlockUpdated}
 import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend}
 import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
@@ -1956,7 +1956,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     Files.write(bm1.diskBlockManager.getFile(shuffleIndex).toPath(), shuffleIndexBlockContent)
     Files.write(bm2.diskBlockManager.getFile(shuffleIndex2).toPath(), shuffleIndexBlockContent)
 
-    mapOutputTracker.registerShuffle(0, 1)
+    mapOutputTracker.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
     val decomManager = new BlockManagerDecommissioner(conf, bm1)
     try {
       mapOutputTracker.registerMapOutput(0, 0, MapStatus(bm1.blockManagerId, Array(blockSize), 0))
