You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/16 07:40:56 UTC

[GitHub] [spark] Ngone51 commented on a change in pull request #30480: [SPARK-32921][SHUFFLE] MapOutputTracker extensions to support push-based shuffle

Ngone51 commented on a change in pull request #30480:
URL: https://github.com/apache/spark/pull/30480#discussion_r614588672



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -181,64 +235,141 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
   def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
     for (mapIndex <- mapStatuses.indices) {
       if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
-        _numAvailableOutputs -= 1
+        _numAvailableMapOutputs -= 1
         mapStatuses(mapIndex) = null
         invalidateSerializedMapOutputStatusCache()
       }
     }
+    for (reduceId <- mergeStatuses.indices) {
+      if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) {
+        _numAvailableMergeResults -= 1
+        mergeStatuses(reduceId) = null
+        invalidateSerializedMergeOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle map outputs.
+   */
+  def numAvailableMapOutputs: Int = withReadLock {
+    _numAvailableMapOutputs
   }
 
   /**
-   * Number of partitions that have shuffle outputs.
+   * Number of shuffle partitions that have already been merge finalized when push-based
+   * is enabled.
    */
-  def numAvailableOutputs: Int = withReadLock {
-    _numAvailableOutputs
+  def numAvailableMergeResults: Int = withReadLock {
+    _numAvailableMergeResults
   }
 
   /**
    * Returns the sequence of partition ids that are missing (i.e. needs to be computed).
    */
   def findMissingPartitions(): Seq[Int] = withReadLock {
     val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    assert(missing.size == numPartitions - _numAvailableMapOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}")
     missing
   }
 
   /**
-   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
-   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   * Serializes the mapStatuses or mergeStatuses array into an efficient compressed format. See
+   * the comments on `MapOutputTracker.serializeOutputStatuses()` for more details on the
+   * serialization format.
    *
    * This method is designed to be called multiple times and implements caching in order to speed
    * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
-   * serialize the map statuses then serialization will only be performed in a single thread and all
-   * other threads will block until the cache is populated.
+   * serialize the statuses array then serialization will only be performed in a single thread and
+   * all other threads will block until the cache is populated.
    */
-  def serializedMapStatus(
+  def serializedOutputStatus(
       broadcastManager: BroadcastManager,
       isLocal: Boolean,
       minBroadcastSize: Int,
-      conf: SparkConf): Array[Byte] = {
-    var result: Array[Byte] = null
+      conf: SparkConf,
+      isMapOnlyOutput: Boolean): (Array[Byte], Array[Byte]) = {

Review comment:
       I think we can rename `isMapOnlyOutput` to `needMergeOutput` and simplify the code below to:
   
   ```scala
   withReadLock {
         if (cachedSerializedMapStatus != null) {
           mapStatuses = cachedSerializedMapStatus
         }
   
         if (needMergeOutput && cachedSerializedMergeStatus != null) {
           mergeStatuses = cachedSerializedMergeStatus
         }
       }
   
       if (mapStatuses == null) {
         mapStatuses =
           serializeAndCacheMapStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
       }
       // If push based shuffle enabled, serialize and cache both Map and Merge Status
       if (needMergeOutput && mergeStatuses == null) {
         mergeStatuses =
           serializeAndCacheMergeStatuses(broadcastManager, isLocal, minBroadcastSize, conf)
       }
   ```

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,61 +1115,151 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot identify the list of
+      // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // Use the MergeStatus's partition level bitmap since we are doing partition level fallback
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
+        mapOutputStatuses, mergeStatus.tracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int,
+      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, _) = getStatuses(shuffleId, conf)
+    try {
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
+        chunkTracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
         throw e
     }
   }
 
   /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+   * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled
+   * for a given shuffle ID. NOTE: clients MUST synchronize
    * on this array when reading it, because on the driver, we may be changing it in place.
    *
    * (It would be nice to remove this restriction in the future.)
    */
-  private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
-      val startTimeNs = System.nanoTime()
-      fetchingLock.withLock(shuffleId) {
-        var fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          try {
-            fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
-          } catch {
-            case e: SparkException =>
-              throw new MetadataFetchFailedException(shuffleId, -1,
-                s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
-                  e.getCause)
+  private def getStatuses(
+      shuffleId: Int,
+      conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
+    if (fetchMergeResult) {
+      val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+      val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull
+
+      if (mapOutputStatuses == null || mergeOutputStatuses == null) {
+        logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
+        val startTimeNs = System.nanoTime()
+        fetchingLock.withLock(shuffleId) {
+          var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+          var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
+          if (fetchedMapStatuses == null || fetchedMergeStatuses == null) {
+            logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+            val fetchedBytes =
+              askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId))

Review comment:
       I may have missed some discussion since my last review, but I think this breaks the decision we made earlier:
   
   we won't affect the existing code path when only map statuses are used.
   
   
   Could you return only the map statuses on the sender side, so the existing behavior stays the same?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1000,18 +1403,55 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
+      // We have MergeStatus and full range of mapIds are requested so return a merged block.
+      val numMaps = mapStatuses.length
+      mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
+        case (mergeStatus, partId) =>
+          val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
+            // If MergeStatus is available for the given partition, add location of the
+            // pre-merged shuffle partition for this partition ID. Here we create a
+            // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
+            // a merged shuffle block.
+            splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
+              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1))
+            // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
+            // shuffle partition blocks, fetch the original map produced shuffle partition blocks
+            mergeStatus.getMissingMaps(numMaps).map(mapStatuses.zipWithIndex)

Review comment:
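       Wouldn't `mapStatuses.zipWithIndex` be evaluated multiple times here, once per partition in the enclosing `foreach`? Consider computing it once, outside the loop.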

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -49,7 +50,10 @@ import org.apache.spark.util._
  *
  * All public methods of this class are thread-safe.
  */
-private class ShuffleStatus(numPartitions: Int) extends Logging {
+private class ShuffleStatus(
+    numPartitions: Int,
+    numReducers: Int,

Review comment:
       I'm thinking we could use `numReducers = -1` to indicate that push-based shuffle is disabled, so we wouldn't need `isPushBasedShuffleEnabled`. That may be a little tricky, though. It's up to you.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +887,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the map output is
+   * pre-merged, then return the node where the merged block is located if the merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {

Review comment:
       Doesn't this path need to respect `shuffleLocalityEnabled` too?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1000,18 +1403,55 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {

Review comment:
       nit: `mergeStatuses.exists(_.nonEmpty)`?
   
   That way we can also skip this path when the merge statuses array is empty.

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,61 +1115,151 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(

Review comment:
       Should this method have a "test only" comment?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -1000,18 +1403,55 @@ private[spark] object MapOutputTracker extends Logging {
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus],
+      mapStatuses: Array[MapStatus],
       startMapIndex : Int,
-      endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    assert (statuses != null)
+      endMapIndex: Int,
+      mergeStatuses: Option[Array[MergeStatus]] = None):
+      Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    assert (mapStatuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
-    val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
-      if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
-      } else {
+    // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+    // partition consists of blocks merged in random order, we are unable to serve map index
+    // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+    // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+    // to AQE to handle.
+    // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+    // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+    // TODO: map indexes
+    if (mergeStatuses.isDefined && startMapIndex == 0 && endMapIndex == mapStatuses.length) {
+      // We have MergeStatus and full range of mapIds are requested so return a merged block.
+      val numMaps = mapStatuses.length
+      mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
+        case (mergeStatus, partId) =>
+          val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {

Review comment:
       I remember Magnet claims it can fall back to the original fetch (using map statuses) when a fetch failure happens. Here, however, it looks like we only collect the merged statuses for those maps, without backup map statuses. (In my mind, we could collect both the merged statuses and the original map statuses together, so that we can fall back if needed.) How do we plan to support the fallback?

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -812,61 +1115,151 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId")
-    val statuses = getStatuses(shuffleId, conf)
+    val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
     try {
-      val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+      val actualEndMapIndex =
+        if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
       logDebug(s"Convert map statuses for shuffle $shuffleId, " +
         s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
       MapOutputTracker.convertMapStatuses(
-        shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+        shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+          actualEndMapIndex, Option(mergedOutputStatuses))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
         mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Fetch the map statuses and merge statuses again since they might have already been
+    // cleared by another task running in the same executor.
+    val (mapOutputStatuses, mergeResultStatuses) = getStatuses(shuffleId, conf)
+    try {
+      val mergeStatus = mergeResultStatuses(partitionId)
+      // If the original MergeStatus is no longer available, we cannot identify the list of
+      // unmerged blocks to fetch in this case. Throw MetadataFetchFailedException in this case.
+      MapOutputTracker.validateStatus(mergeStatus, shuffleId, partitionId)
+      // Use the MergeStatus's partition level bitmap since we are doing partition level fallback
+      MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId,
+        mapOutputStatuses, mergeStatus.tracker)
+    } catch {
+      // We experienced a fetch failure so our mapStatuses cache is outdated; clear it
+      case e: MetadataFetchFailedException =>
+        mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+  override def getMapSizesForMergeResult(

Review comment:
       Should this method have a "test only" comment?
   
   

##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -633,23 +887,50 @@ private[spark] class MapOutputTrackerMaster(
 
   /**
    * Return the preferred hosts on which to run the given map output partition in a given shuffle,
-   * i.e. the nodes that the most outputs for that partition are on.
+   * i.e. the nodes that the most outputs for that partition are on. If the map output is
+   * pre-merged, then return the node where the merged block is located if the merge ratio is
+   * above the threshold.
    *
    * @param dep shuffle dependency object
    * @param partitionId map output partition that we want to read
    * @return a sequence of host names
    */
   def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
       : Seq[String] = {
-    if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
-        dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
-      val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
-        dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
-      if (blockManagerIds.nonEmpty) {
-        blockManagerIds.get.map(_.host)
+    val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+    if (shuffleStatus != null) {
+      // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+      // If so, the location of the merged block is the preferred location.
+      val preferredLoc = if (pushBasedShuffleEnabled) {
+        shuffleStatus.withMergeStatuses { statuses =>
+          val status = statuses(partitionId)
+          val numMaps = dep.rdd.partitions.length
+          if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
+            <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+            Seq(status.location.host)
+          } else {
+            Nil
+          }
+        }
       } else {
         Nil
       }
+      if (!preferredLoc.isEmpty) {

Review comment:
       nit: `preferredLoc.nonEmpty`




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org