You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by mr...@apache.org on 2020/11/20 12:01:55 UTC

[spark] branch master updated: [SPARK-32919][SHUFFLE][TEST-MAVEN][TEST-HADOOP2.7] Driver side changes for coordinating push based shuffle by selecting external shuffle services for merging partitions

This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 8218b48  [SPARK-32919][SHUFFLE][TEST-MAVEN][TEST-HADOOP2.7] Driver side changes for coordinating push based shuffle by selecting external shuffle services for merging partitions
8218b48 is described below

commit 8218b488035049434271dc9e3bd5af45ffadf0fd
Author: Venkata krishnan Sowrirajan <vs...@linkedin.com>
AuthorDate: Fri Nov 20 06:00:30 2020 -0600

    [SPARK-32919][SHUFFLE][TEST-MAVEN][TEST-HADOOP2.7] Driver side changes for coordinating push based shuffle by selecting external shuffle services for merging partitions
    
    ### What changes were proposed in this pull request?
    Driver side changes for coordinating push based shuffle by selecting external shuffle services for merging partitions.
    
    This PR includes changes related to `ShuffleMapStage` preparation, namely selecting merger locations and storing them on the `ShuffleDependency`.
    
    Currently this code is not used as some of the changes would come subsequently as part of https://issues.apache.org/jira/browse/SPARK-32917 (shuffle blocks push as part of `ShuffleMapTask`), https://issues.apache.org/jira/browse/SPARK-32918 (support for finalize API) and https://issues.apache.org/jira/browse/SPARK-32920 (finalization of push/merge phase). This is why the tests here are also partial, once these above mentioned changes are raised as PR we will have enough tests for DAGS [...]
    
    ### Why are the changes needed?
    Added a new API in `SchedulerBackend` to get merger locations for push based shuffle. This is currently implemented for Yarn and other cluster managers can have separate implementations which is why a new API is introduced.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, user facing config to enable push based shuffle is introduced
    
    ### How was this patch tested?
    Added partial unit tests. Some of the DAGScheduler changes depend on future changes, so the DAGScheduler tests will be added along with those changes.
    
    Lead-authored-by: Venkata krishnan Sowrirajan vsowrirajanlinkedin.com
    Co-authored-by: Min Shen mshenlinkedin.com
    
    Closes #30164 from venkata91/upstream-SPARK-32919.
    
    Lead-authored-by: Venkata krishnan Sowrirajan <vs...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../main/scala/org/apache/spark/Dependency.scala   | 15 +++++
 .../org/apache/spark/internal/config/package.scala | 47 ++++++++++++++++
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 40 +++++++++++++
 .../apache/spark/scheduler/SchedulerBackend.scala  | 13 +++++
 .../org/apache/spark/storage/BlockManagerId.scala  |  2 +
 .../apache/spark/storage/BlockManagerMaster.scala  | 20 +++++++
 .../spark/storage/BlockManagerMasterEndpoint.scala | 65 ++++++++++++++++++++++
 .../spark/storage/BlockManagerMessages.scala       |  6 ++
 .../main/scala/org/apache/spark/util/Utils.scala   |  8 +++
 .../apache/spark/storage/BlockManagerSuite.scala   | 49 +++++++++++++++-
 .../scala/org/apache/spark/util/UtilsSuite.scala   | 12 ++++
 .../scheduler/cluster/YarnSchedulerBackend.scala   | 50 +++++++++++++++--
 12 files changed, 320 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index ba8e4d6..d21b9d9 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -23,6 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor}
+import org.apache.spark.storage.BlockManagerId
 
 /**
  * :: DeveloperApi ::
@@ -95,6 +96,20 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
   val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
     shuffleId, this)
 
+  /**
+   * Stores the location of the list of chosen external shuffle services for handling the
+   * shuffle merge requests from mappers in this shuffle map stage.
+   */
+  private[spark] var mergerLocs: Seq[BlockManagerId] = Nil
+
+  def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
+    // Option(...) turns a null argument into None, so a null is ignored and the
+    // previously stored locations are kept.
+    Option(mergerLocs).foreach(locs => this.mergerLocs = locs)
+  }
+
+  def getMergerLocs: Seq[BlockManagerId] = mergerLocs
+
   _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 4bc4951..b38d0e5 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -1945,4 +1945,51 @@ package object config {
       .version("3.0.1")
       .booleanConf
       .createWithDefault(false)
+
+  private[spark] val PUSH_BASED_SHUFFLE_ENABLED =
+    ConfigBuilder("spark.shuffle.push.enabled")
+      .doc("Set to 'true' to enable push-based shuffle on the client side and this works in " +
+        "conjunction with the server side flag spark.shuffle.server.mergedShuffleFileManagerImpl " +
+        "which needs to be set with the appropriate " +
+        "org.apache.spark.network.shuffle.MergedShuffleFileManager implementation for push-based " +
+        "shuffle to be enabled")
+      .version("3.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
+  private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS =
+    ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations")
+      .doc("Maximum number of shuffle push merger locations cached for push based shuffle. " +
+        "Currently, shuffle push merger locations are nothing but external shuffle services " +
+        "which are responsible for handling pushed blocks and merging them and serving " +
+        "merged blocks for later shuffle fetch.")
+      .version("3.1.0")
+      .intConf
+      .createWithDefault(500)
+
+  private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO =
+    ConfigBuilder("spark.shuffle.push.mergersMinThresholdRatio")
+      .doc("The minimum number of shuffle merger locations required to enable push based " +
+        "shuffle for a stage. This is specified as a ratio of the number of partitions in " +
+        "the child stage. For example, a reduce stage which has 100 partitions and uses the " +
+        "default value 0.05 requires at least 5 unique merger locations to enable push based " +
+        "shuffle. Merger locations are currently defined as external shuffle services.")
+      .version("3.1.0")
+      .doubleConf
+      .createWithDefault(0.05)
+
+  private[spark] val SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD =
+    ConfigBuilder("spark.shuffle.push.mergersMinStaticThreshold")
+      .doc("The static threshold for the number of shuffle push merger locations that should " +
+        "be available in order to enable push based shuffle for a stage. Note this config " +
+        s"works in conjunction with ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}. " +
+        "The number of mergers needed to enable push based shuffle for a stage is the maximum " +
+        "of spark.shuffle.push.mergersMinStaticThreshold and the number derived from " +
+        s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key}. For example, with 1000 " +
+        "partitions for the child stage, spark.shuffle.push.mergersMinStaticThreshold set to " +
+        s"5 and ${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " +
+        "at least 50 mergers to enable push based shuffle for that stage.")
+      .version("3.1.0")
+      .doubleConf
+      .createWithDefault(5)
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 13b766e..6fb0fb9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -249,6 +249,8 @@ private[spark] class DAGScheduler(
   private[spark] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
   taskScheduler.setDAGScheduler(this)
 
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(sc.getConf)
+
   /**
    * Called by the TaskSetManager to report task's starting.
    */
@@ -1252,6 +1254,33 @@ private[spark] class DAGScheduler(
     execCores.map(cores => properties.setProperty(EXECUTOR_CORES_LOCAL_PROPERTY, cores))
   }
 
+  /**
+   * If push based shuffle is enabled, set the shuffle services to be used for the given
+   * shuffle map stage for block push/merge.
+   *
+   * Even with dynamic resource allocation kicking in and significantly reducing the number
+   * of available active executors, we would still be able to get sufficient shuffle service
+   * locations for block push/merge by getting the historical locations of past executors.
+   */
+  private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = {
+    // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize
+    // TODO changes we cannot disable shuffle merge for the retry/reuse cases
+    // Ask for mergers only while none are recorded, so a resubmitted stage keeps its first set.
+    if (stage.shuffleDep.getMergerLocs.isEmpty) {
+      stage.shuffleDep.setMergerLocs(sc.schedulerBackend.getShufflePushMergerLocations(
+        stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId))
+    }
+    if (stage.shuffleDep.getMergerLocs.nonEmpty) {
+      logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" +
+        s" ${stage.shuffleDep.getMergerLocs.size} merger locations")
+      logDebug("List of shuffle push merger locations " +
+        s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}")
+    } else {
+      logInfo("No available merger locations." +
+        s" Push-based shuffle disabled for $stage (${stage.name})")
+    }
+  }
+
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int): Unit = {
     logDebug("submitMissingTasks(" + stage + ")")
@@ -1281,6 +1310,12 @@ private[spark] class DAGScheduler(
     stage match {
       case s: ShuffleMapStage =>
         outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1)
+        // Only generate merger location for a given shuffle dependency once. This way, even if
+        // this stage gets retried, it would still be merging blocks using the same set of
+        // shuffle services.
+        if (pushBasedShuffleEnabled) {
+          prepareShuffleServicesForShuffleMapStage(s)
+        }
       case s: ResultStage =>
         outputCommitCoordinator.stageStart(
           stage = s.id, maxPartitionId = s.rdd.partitions.length - 1)
@@ -2027,6 +2062,11 @@ private[spark] class DAGScheduler(
     if (!executorFailureEpoch.contains(execId) || executorFailureEpoch(execId) < currentEpoch) {
       executorFailureEpoch(execId) = currentEpoch
       logInfo(s"Executor lost: $execId (epoch $currentEpoch)")
+      if (pushBasedShuffleEnabled) {
+        // Remove fetchFailed host in the shuffle push merger list for push based shuffle
+        hostToUnregisterOutputs.foreach(
+          host => blockManagerMaster.removeShufflePushMergerLocation(host))
+      }
       blockManagerMaster.removeExecutor(execId)
       clearCacheLocs()
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index a566d0a..b2acdb3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import org.apache.spark.resource.ResourceProfile
+import org.apache.spark.storage.BlockManagerId
 
 /**
  * A backend interface for scheduling systems that allows plugging in different ones under
@@ -92,4 +93,16 @@ private[spark] trait SchedulerBackend {
    */
   def maxNumConcurrentTasks(rp: ResourceProfile): Int
 
+  /**
+   * Get the list of host locations for push based shuffle
+   *
+   * Currently push based shuffle is disabled for both stage retry and stage reuse cases
+   * (for eg: in the case where few partitions are lost due to failure). Hence this method
+   * should be invoked only once for a ShuffleDependency.
+   * @return List of external shuffle services locations
+   */
+  def getShufflePushMergerLocations(
+      numPartitions: Int,
+      resourceProfileId: Int): Seq[BlockManagerId] = Nil
+
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 49e32d0..c6a4457 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -145,4 +145,6 @@ private[spark] object BlockManagerId {
   def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
     blockManagerIdCache.get(id)
   }
+
+  private[spark] val SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index f544d47..fe1a5aef 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -125,6 +125,26 @@ class BlockManagerMaster(
     driverEndpoint.askSync[Seq[BlockManagerId]](GetPeers(blockManagerId))
   }
 
+  /**
+   * Get a list of unique shuffle service locations on hosts where an executor was
+   * successfully registered in the past, for block push/merge with push based shuffle.
+   */
+  def getShufflePushMergerLocations(
+      numMergersNeeded: Int,
+      hostsToFilter: Set[String]): Seq[BlockManagerId] = {
+    val request = GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter)
+    driverEndpoint.askSync[Seq[BlockManagerId]](request)
+  }
+
+  /**
+   * Remove the host from the candidate list of shuffle push mergers. This can be
+   * triggered if there is a FetchFailedException on the host. The reply carries no value.
+   * @param host host name to remove from the candidate merger list
+   */
+  def removeShufflePushMergerLocation(host: String): Unit = {
+    driverEndpoint.askSync[Unit](RemoveShufflePushMergerLocation(host))
+  }
+
   def getExecutorEndpointRef(executorId: String): Option[RpcEndpointRef] = {
     driverEndpoint.askSync[Option[RpcEndpointRef]](GetExecutorEndpointRef(executorId))
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index a7532a9..4d56551 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -74,6 +74,14 @@ class BlockManagerMasterEndpoint(
   // Mapping from block id to the set of block managers that have the block.
   private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]]
 
+  // Mapping from host name to shuffle (mergers) services where the current app
+  // registered an executor in the past. Older hosts are removed when the
+  // maxRetainedMergerLocations size is reached in favor of newer locations.
+  private val shuffleMergerLocations = new mutable.LinkedHashMap[String, BlockManagerId]()
+
+  // Maximum number of merger locations to cache
+  private val maxRetainedMergerLocations = conf.get(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS)
+
   private val askThreadPool =
     ThreadUtils.newDaemonCachedThreadPool("block-manager-ask-thread-pool", 100)
   private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
@@ -92,6 +100,8 @@ class BlockManagerMasterEndpoint(
 
   val defaultRpcTimeout = RpcUtils.askRpcTimeout(conf)
 
+  private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
+
   logInfo("BlockManagerMasterEndpoint up")
   // same as `conf.get(config.SHUFFLE_SERVICE_ENABLED)
   //   && conf.get(config.SHUFFLE_SERVICE_FETCH_RDD_ENABLED)`
@@ -139,6 +149,12 @@ class BlockManagerMasterEndpoint(
     case GetBlockStatus(blockId, askStorageEndpoints) =>
       context.reply(blockStatus(blockId, askStorageEndpoints))
 
+    case GetShufflePushMergerLocations(numMergersNeeded, hostsToFilter) =>
+      context.reply(getShufflePushMergerLocations(numMergersNeeded, hostsToFilter))
+
+    case RemoveShufflePushMergerLocation(host) =>
+      context.reply(removeShufflePushMergerLocation(host))
+
     case IsExecutorAlive(executorId) =>
       context.reply(blockManagerIdByExecutor.contains(executorId))
 
@@ -360,6 +376,17 @@ class BlockManagerMasterEndpoint(
 
   }
 
+  private def addMergerLocation(blockManagerId: BlockManagerId): Unit = {
+    if (!(blockManagerId.isDriver || shuffleMergerLocations.contains(blockManagerId.host))) {
+      val mergerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER,
+        blockManagerId.host, externalShuffleServicePort)
+      if (maxRetainedMergerLocations <= shuffleMergerLocations.size) {
+        shuffleMergerLocations.remove(shuffleMergerLocations.head._1)
+      }
+      shuffleMergerLocations(mergerId.host) = mergerId
+    }
+  }
+
   private def removeExecutor(execId: String): Unit = {
     logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
     blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
@@ -526,6 +553,10 @@ class BlockManagerMasterEndpoint(
 
       blockManagerInfo(id) = new BlockManagerInfo(id, System.currentTimeMillis(),
         maxOnHeapMemSize, maxOffHeapMemSize, storageEndpoint, externalShuffleServiceBlockStatus)
+
+      if (pushBasedShuffleEnabled) {
+        addMergerLocation(id)
+      }
     }
     listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxOnHeapMemSize + maxOffHeapMemSize,
         Some(maxOnHeapMemSize), Some(maxOffHeapMemSize)))
@@ -657,6 +688,40 @@ class BlockManagerMasterEndpoint(
     }
   }
 
+  private def getShufflePushMergerLocations(
+      numMergersNeeded: Int,
+      hostsToFilter: Set[String]): Seq[BlockManagerId] = {
+    val blockManagerHosts = blockManagerIdByExecutor.values.map(_.host).toSet
+    val filteredBlockManagerHosts = blockManagerHosts.filterNot(hostsToFilter.contains(_))
+    val filteredMergersWithExecutors = filteredBlockManagerHosts.map(
+      BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, _, externalShuffleServicePort))
+    // Hosts with active executors already cover the request; all of them are returned.
+    if (filteredMergersWithExecutors.size >= numMergersNeeded) {
+      filteredMergersWithExecutors.toSeq
+    } else {
+      // Otherwise top up from cached merger hosts that currently have no active executor.
+      val filteredMergersWithExecutorsHosts = filteredMergersWithExecutors.map(_.host)
+      val filteredMergersWithoutExecutors = shuffleMergerLocations.values
+        .filterNot(x => hostsToFilter.contains(x.host))
+        .filterNot(x => filteredMergersWithExecutorsHosts.contains(x.host))
+      val randomFilteredMergersLocations =
+        if (filteredMergersWithoutExecutors.size >
+          numMergersNeeded - filteredMergersWithExecutors.size) {
+          Utils.randomize(filteredMergersWithoutExecutors)
+            .take(numMergersNeeded - filteredMergersWithExecutors.size)
+        } else {
+          filteredMergersWithoutExecutors
+        }
+      filteredMergersWithExecutors.toSeq ++ randomFilteredMergersLocations
+    }
+  }
+
+  private def removeShufflePushMergerLocation(host: String): Unit = {
+    // Dropping a key that is not cached is a no-op on the map, so no separate
+    // membership check is required before removing the host.
+    shuffleMergerLocations -= host
+  }
+
   /**
    * Returns an [[RpcEndpointRef]] of the [[BlockManagerReplicaEndpoint]] for sending RPC messages.
    */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index bbc076c..afe416a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -141,4 +141,10 @@ private[spark] object BlockManagerMessages {
   case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
 
   case class IsExecutorAlive(executorId: String) extends ToBlockManagerMaster
+
+  case class GetShufflePushMergerLocations(numMergersNeeded: Int, hostsToFilter: Set[String])
+    extends ToBlockManagerMaster
+
+  case class RemoveShufflePushMergerLocation(host: String) extends ToBlockManagerMaster
+
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b743ab6..6ccf65b 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2542,6 +2542,14 @@ private[spark] object Utils extends Logging {
   }
 
   /**
+   * Push based shuffle needs the external shuffle service, unless IS_TESTING is set.
+   */
+  def isPushBasedShuffleEnabled(conf: SparkConf): Boolean = {
+    val pushRequested = conf.get(PUSH_BASED_SHUFFLE_ENABLED)
+    pushRequested && (conf.get(IS_TESTING).contains(true) || conf.get(SHUFFLE_SERVICE_ENABLED))
+  }
+
+  /**
    * Return whether dynamic allocation is enabled in the given conf.
    */
   def isDynamicAllocationEnabled(conf: SparkConf): Boolean = {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 55280fc..144489c 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -100,6 +100,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
       .set(Kryo.KRYO_SERIALIZER_BUFFER_SIZE.key, "1m")
       .set(STORAGE_UNROLL_MEMORY_THRESHOLD, 512L)
       .set(Network.RPC_ASK_TIMEOUT, "5s")
+      .set(PUSH_BASED_SHUFFLE_ENABLED, true)
   }
 
   private def makeSortShuffleManager(): SortShuffleManager = {
@@ -1974,6 +1975,48 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     }
   }
 
+  test("SPARK-32919: Shuffle push merger locations should be bounded within" +
+    " spark.shuffle.push.maxRetainedMergerLocations") {
+    assert(master.getShufflePushMergerLocations(10, Set.empty).isEmpty)
+    makeBlockManager(100, "execA",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    makeBlockManager(100, "execB",
+      transferService = Some(new MockBlockTransferService(10, "hostB")))
+    makeBlockManager(100, "execC",
+      transferService = Some(new MockBlockTransferService(10, "hostC")))
+    makeBlockManager(100, "execD",
+      transferService = Some(new MockBlockTransferService(10, "hostD")))
+    makeBlockManager(100, "execE",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    assert(master.getShufflePushMergerLocations(10, Set.empty).size == 4)
+    assert(master.getShufflePushMergerLocations(10, Set.empty).map(_.host).sorted ===
+      Seq("hostC", "hostD", "hostA", "hostB").sorted)
+    assert(master.getShufflePushMergerLocations(10, Set("hostB")).size == 3)
+  }
+
+  test("SPARK-32919: Prefer active executor locations for shuffle push mergers") {
+    makeBlockManager(100, "execA",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    makeBlockManager(100, "execB",
+      transferService = Some(new MockBlockTransferService(10, "hostB")))
+    makeBlockManager(100, "execC",
+      transferService = Some(new MockBlockTransferService(10, "hostC")))
+    makeBlockManager(100, "execD",
+      transferService = Some(new MockBlockTransferService(10, "hostD")))
+    makeBlockManager(100, "execE",
+      transferService = Some(new MockBlockTransferService(10, "hostA")))
+    assert(master.getShufflePushMergerLocations(5, Set.empty).size == 4)
+
+    master.removeExecutor("execA")
+    master.removeExecutor("execE")
+
+    assert(master.getShufflePushMergerLocations(3, Set.empty).size == 3)
+    assert(master.getShufflePushMergerLocations(3, Set.empty).map(_.host).sorted ===
+      Seq("hostC", "hostB", "hostD").sorted)
+    assert(master.getShufflePushMergerLocations(4, Set.empty).map(_.host).sorted ===
+      Seq("hostB", "hostA", "hostC", "hostD").sorted)
+  }
+
   test("SPARK-33387 Support ordered shuffle block migration") {
     val blocks: Seq[ShuffleBlockInfo] = Seq(
       ShuffleBlockInfo(1, 0L),
@@ -1995,7 +2038,9 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
     assert(sortedBlocks.sameElements(decomManager.shufflesToMigrate.asScala.map(_._1)))
   }
 
-  class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService {
+  class MockBlockTransferService(
+      val maxFailures: Int,
+      override val hostName: String = "MockBlockTransferServiceHost") extends BlockTransferService {
     var numCalls = 0
     var tempFileManager: DownloadFileManager = null
 
@@ -2013,8 +2058,6 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
 
     override def close(): Unit = {}
 
-    override def hostName: String = { "MockBlockTransferServiceHost" }
-
     override def port: Int = { 63332 }
 
     override def uploadBlock(
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 20624c7..8fb4080 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -41,6 +41,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
+import org.apache.spark.internal.config.Tests.IS_TESTING
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.scheduler.SparkListener
 import org.apache.spark.util.io.ChunkedByteBufferInputStream
@@ -1432,6 +1433,17 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     }.getMessage
     assert(message.contains(expected))
   }
+
+  test("isPushBasedShuffleEnabled when both PUSH_BASED_SHUFFLE_ENABLED" +
+    " and SHUFFLE_SERVICE_ENABLED are true") {
+    val conf = new SparkConf()
+    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    conf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(IS_TESTING, false)
+    assert(Utils.isPushBasedShuffleEnabled(conf) === false)
+    conf.set(SHUFFLE_SERVICE_ENABLED, true)
+    assert(Utils.isPushBasedShuffleEnabled(conf) === true)
+  }
 }
 
 private class SimpleExtension
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index b42bdb9..22002bb3 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler.cluster
 
 import java.util.EnumSet
-import java.util.concurrent.atomic.{AtomicBoolean}
+import java.util.concurrent.atomic.AtomicBoolean
 import javax.servlet.DispatcherType
 
 import scala.concurrent.{ExecutionContext, Future}
@@ -29,14 +29,14 @@ import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
 
 import org.apache.spark.SparkContext
 import org.apache.spark.deploy.security.HadoopDelegationTokenManager
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config
+import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.internal.config.UI._
 import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.rpc._
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{RpcUtils, ThreadUtils}
+import org.apache.spark.storage.{BlockManagerId, BlockManagerMaster}
+import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils}
 
 /**
  * Abstract Yarn scheduler backend that contains common logic
@@ -80,6 +80,18 @@ private[spark] abstract class YarnSchedulerBackend(
   /** Attempt ID. This is unset for client-mode schedulers */
   private var attemptId: Option[ApplicationAttemptId] = None
 
+  private val blockManagerMaster: BlockManagerMaster = sc.env.blockManager.master
+
+  private val minMergersThresholdRatio =
+    conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO)
+
+  private val minMergersStaticThreshold =
+    conf.get(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD)
+
+  private val maxNumExecutors = conf.get(config.DYN_ALLOCATION_MAX_EXECUTORS)
+
+  private val numExecutors = conf.get(config.EXECUTOR_INSTANCES).getOrElse(0)
+
   /**
    * Bind to YARN. This *must* be done before calling [[start()]].
    *
@@ -161,6 +173,36 @@ private[spark] abstract class YarnSchedulerBackend(
     totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
   }
 
+  override def getShufflePushMergerLocations(
+      numPartitions: Int,
+      resourceProfileId: Int): Seq[BlockManagerId] = {
+    // TODO (SPARK-33481) This is a naive way of calculating numMergersDesired for a stage,
+    // TODO we can use better heuristics to calculate numMergersDesired for a stage.
+    val maxExecutors =
+      if (Utils.isDynamicAllocationEnabled(sc.getConf)) maxNumExecutors else numExecutors
+    val tasksPerExecutor = sc.resourceProfileManager
+      .resourceProfileFromId(resourceProfileId).maxTasksPerExecutor(sc.conf)
+    // Divide as Double: an Int division would truncate before math.ceil could round up,
+    // e.g. 10 partitions with 4 tasks per executor must give 3 mergers, not 2.
+    val numMergersDesired = math.min(
+      math.max(1, math.ceil(numPartitions.toDouble / tasksPerExecutor).toInt), maxExecutors)
+    // Push based shuffle needs the larger of the static threshold and the ratio-based count.
+    val minMergersNeeded = math.max(minMergersStaticThreshold,
+      math.floor(numMergersDesired * minMergersThresholdRatio).toInt)
+
+    // Request for numMergersDesired shuffle mergers to BlockManagerMasterEndpoint
+    // and if it's less than minMergersNeeded, we disable push based shuffle.
+    val mergerLocations = blockManagerMaster
+      .getShufflePushMergerLocations(numMergersDesired, scheduler.excludedNodes())
+    if (mergerLocations.size < numMergersDesired && mergerLocations.size < minMergersNeeded) {
+      Seq.empty[BlockManagerId]
+    } else {
+      logDebug(s"The number of shuffle mergers desired ${numMergersDesired}" +
+        s" and available locations are ${mergerLocations.length}")
+      mergerLocations
+    }
+  }
+
   /**
    * Add filters to the SparkUI.
    */


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org