You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by mr...@apache.org on 2021/04/26 05:17:51 UTC
[spark] branch master updated: [SPARK-32921][SHUFFLE]
MapOutputTracker extensions to support push-based shuffle
This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 38ef477 [SPARK-32921][SHUFFLE] MapOutputTracker extensions to support push-based shuffle
38ef477 is described below
commit 38ef4771d447f6135382ee2767b3f32b96cb1b0e
Author: Venkata krishnan Sowrirajan <vs...@linkedin.com>
AuthorDate: Mon Apr 26 00:17:26 2021 -0500
[SPARK-32921][SHUFFLE] MapOutputTracker extensions to support push-based shuffle
### What changes were proposed in this pull request?
This is one of the patches for SPIP SPARK-30602 for push-based shuffle.
Summary of changes:
- Introduce `MergeStatus`, which tracks the partition-level metadata for a merged shuffle partition in the Spark driver
- Unify `MergeStatus` and `MapStatus` under a single trait to allow code reuse inside `MapOutputTracker`
- Extend `MapOutputTracker` to support registering/unregistering `MergeStatus`, calculating preferred locations for a shuffle while taking merged shuffle partitions into consideration, and serving reducer requests for block fetch locations that include merged shuffle partitions.
The added APIs in `MapOutputTracker` will be used by `DAGScheduler` in SPARK-32920 and by `ShuffleBlockFetcherIterator` in SPARK-32922.
### Why are the changes needed?
Refer to SPARK-30602
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests.
Lead-authored-by: Min Shen mshenlinkedin.com
Co-authored-by: Chandni Singh chsinghlinkedin.com
Co-authored-by: Venkata Sowrirajan vsowrirajanlinkedin.com
Closes #30480 from Victsm/SPARK-32921.
Lead-authored-by: Venkata krishnan Sowrirajan <vs...@linkedin.com>
Co-authored-by: Min Shen <ms...@linkedin.com>
Co-authored-by: Chandni Singh <si...@gmail.com>
Co-authored-by: Chandni Singh <ch...@linkedin.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../scala/org/apache/spark/MapOutputTracker.scala | 670 ++++++++++++++++++---
.../org/apache/spark/scheduler/DAGScheduler.scala | 3 +-
.../org/apache/spark/scheduler/MapStatus.scala | 8 +-
.../org/apache/spark/scheduler/MergeStatus.scala | 113 ++++
.../org/apache/spark/MapOutputTrackerSuite.scala | 317 +++++++++-
.../spark/MapStatusesSerDeserBenchmark.scala | 10 +-
.../test/scala/org/apache/spark/ShuffleSuite.scala | 6 +-
.../apache/spark/storage/BlockManagerSuite.scala | 4 +-
8 files changed, 1006 insertions(+), 125 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ce71c2c..b749d7e 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -29,13 +29,14 @@ import scala.reflect.ClassTag
import scala.util.control.NonFatal
import org.apache.commons.io.output.{ByteArrayOutputStream => ApacheByteArrayOutputStream}
+import org.roaringbitmap.RoaringBitmap
import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus}
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util._
@@ -49,7 +50,9 @@ import org.apache.spark.util._
*
* All public methods of this class are thread-safe.
*/
-private class ShuffleStatus(numPartitions: Int) extends Logging {
+private class ShuffleStatus(
+ numPartitions: Int,
+ numReducers: Int = -1) extends Logging {
private val (readLock, writeLock) = {
val lock = new ReentrantReadWriteLock()
@@ -87,6 +90,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
val mapStatuses = new Array[MapStatus](numPartitions)
/**
+ * MergeStatus for each shuffle partition when push-based shuffle is enabled. The index of the
+ * array is the shuffle partition id (reduce id). Each value in the array is the MergeStatus for
+ * a shuffle partition, or null if not available. When push-based shuffle is enabled, this array
+ * provides a reducer oriented view of the shuffle status specifically for the results of
+ * merging shuffle partition blocks into per-partition merged shuffle files.
+ */
+  val mergeStatuses: Array[MergeStatus] =
+    if (numReducers <= 0) Array.empty[MergeStatus]
+    else new Array[MergeStatus](numReducers)
+
+ /**
* The cached result of serializing the map statuses array. This cache is lazily populated when
* [[serializedMapStatus]] is called. The cache is invalidated when map outputs are removed.
*/
@@ -103,11 +119,23 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
private[spark] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
/**
+ * Similar to cachedSerializedMapStatus and cachedSerializedBroadcast, but for MergeStatus.
+ */
+ private[this] var cachedSerializedMergeStatus: Array[Byte] = _
+
+ private[this] var cachedSerializedBroadcastMergeStatus: Broadcast[Array[Byte]] = _
+
+ /**
* Counter tracking the number of partitions that have output. This is a performance optimization
* to avoid having to count the number of non-null entries in the `mapStatuses` array and should
* be equivalent to`mapStatuses.count(_ ne null)`.
*/
- private[this] var _numAvailableOutputs: Int = 0
+ private[this] var _numAvailableMapOutputs: Int = 0
+
+ /**
+ * Counter tracking the number of MergeStatus results received so far from the shuffle services.
+ */
+ private[this] var _numAvailableMergeResults: Int = 0
/**
* Register a map output. If there is already a registered location for the map output then it
@@ -115,7 +143,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
*/
def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
if (mapStatuses(mapIndex) == null) {
- _numAvailableOutputs += 1
+ _numAvailableMapOutputs += 1
invalidateSerializedMapOutputStatusCache()
}
mapStatuses(mapIndex) = status
@@ -149,13 +177,37 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
def removeMapOutput(mapIndex: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
logDebug(s"Removing existing map output ${mapIndex} ${bmAddress}")
if (mapStatuses(mapIndex) != null && mapStatuses(mapIndex).location == bmAddress) {
- _numAvailableOutputs -= 1
+ _numAvailableMapOutputs -= 1
mapStatuses(mapIndex) = null
invalidateSerializedMapOutputStatusCache()
}
}
/**
+ * Register a merge result.
+ */
+  def addMergeResult(reduceId: Int, status: MergeStatus): Unit = withWriteLock {
+    val previous = mergeStatuses(reduceId)
+    if (previous != status) {
+      // Keep _numAvailableMergeResults equal to the number of non-null slots: count a
+      // partition only when it goes from empty to filled (or back), never when an existing
+      // MergeStatus is merely replaced, since removeMergeResult decrements only once.
+      if (previous == null) {
+        _numAvailableMergeResults += 1
+      } else if (status == null) {
+        _numAvailableMergeResults -= 1
+      }
+      invalidateSerializedMergeOutputStatusCache()
+    }
+    mergeStatuses(reduceId) = status
+  }
+
+ // TODO support updateMergeResult for similar use cases as updateMapOutput
+
+ /**
+ * Remove the merge result which was served by the specified block manager.
+ */
+  def removeMergeResult(reduceId: Int, bmAddress: BlockManagerId): Unit = withWriteLock {
+    val current = mergeStatuses(reduceId)
+    // Only drop the entry when it is actually served from the given block manager.
+    val servedByAddress = current != null && current.location == bmAddress
+    if (servedByAddress) {
+      mergeStatuses(reduceId) = null
+      _numAvailableMergeResults -= 1
+      invalidateSerializedMergeOutputStatusCache()
+    }
+  }
+
+ /**
* Removes all shuffle outputs associated with this host. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists).
*/
@@ -181,18 +233,33 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
def removeOutputsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock {
for (mapIndex <- mapStatuses.indices) {
if (mapStatuses(mapIndex) != null && f(mapStatuses(mapIndex).location)) {
- _numAvailableOutputs -= 1
+ _numAvailableMapOutputs -= 1
mapStatuses(mapIndex) = null
invalidateSerializedMapOutputStatusCache()
}
}
+ for (reduceId <- mergeStatuses.indices) {
+ if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) {
+ _numAvailableMergeResults -= 1
+ mergeStatuses(reduceId) = null
+ invalidateSerializedMergeOutputStatusCache()
+ }
+ }
+ }
+
+ /**
+ * Number of partitions that have shuffle map outputs.
+ */
+ def numAvailableMapOutputs: Int = withReadLock {
+ _numAvailableMapOutputs
}
/**
- * Number of partitions that have shuffle outputs.
+ * Number of shuffle partitions that have already been merge finalized when push-based
+ * is enabled.
*/
- def numAvailableOutputs: Int = withReadLock {
- _numAvailableOutputs
+ def numAvailableMergeResults: Int = withReadLock {
+ _numAvailableMergeResults
}
/**
@@ -200,19 +267,19 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
*/
def findMissingPartitions(): Seq[Int] = withReadLock {
val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
- assert(missing.size == numPartitions - _numAvailableOutputs,
- s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+ assert(missing.size == numPartitions - _numAvailableMapOutputs,
+ s"${missing.size} missing, expected ${numPartitions - _numAvailableMapOutputs}")
missing
}
/**
* Serializes the mapStatuses array into an efficient compressed format. See the comments on
- * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+ * `MapOutputTracker.serializeOutputStatuses()` for more details on the serialization format.
*
* This method is designed to be called multiple times and implements caching in order to speed
* up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
- * serialize the map statuses then serialization will only be performed in a single thread and all
- * other threads will block until the cache is populated.
+ * serialize the map statuses then serialization will only be performed in a single thread and
+ * all other threads will block until the cache is populated.
*/
def serializedMapStatus(
broadcastManager: BroadcastManager,
@@ -220,7 +287,6 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
minBroadcastSize: Int,
conf: SparkConf): Array[Byte] = {
var result: Array[Byte] = null
-
withReadLock {
if (cachedSerializedMapStatus != null) {
result = cachedSerializedMapStatus
@@ -229,7 +295,7 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
if (result == null) withWriteLock {
if (cachedSerializedMapStatus == null) {
- val serResult = MapOutputTracker.serializeMapStatuses(
+ val serResult = MapOutputTracker.serializeOutputStatuses[MapStatus](
mapStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
cachedSerializedMapStatus = serResult._1
cachedSerializedBroadcast = serResult._2
@@ -241,6 +307,47 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
result
}
+ /**
+ * Serializes the mapStatuses and mergeStatuses arrays into an efficient compressed format.
+ * See the comments on `MapOutputTracker.serializeOutputStatuses()` for more details
+ * on the serialization format.
+ *
+ * This method is designed to be called multiple times and implements caching in order to speed
+ * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
+ * serialize the statuses arrays then serialization will only be performed in a single thread and
+ * all other threads will block until the cache is populated.
+ *
+ * The map half is obtained from [[serializedMapStatus]]; the merge half is cached separately in
+ * `cachedSerializedMergeStatus` / `cachedSerializedBroadcastMergeStatus`.
+ * NOTE(review): the two halves are produced under separate lock acquisitions, so they are not
+ * guaranteed to form one consistent snapshot if outputs change concurrently.
+ *
+ * @return a pair of (serialized map statuses, serialized merge statuses).
+ */
+ def serializedMapAndMergeStatus(
+ broadcastManager: BroadcastManager,
+ isLocal: Boolean,
+ minBroadcastSize: Int,
+ conf: SparkConf): (Array[Byte], Array[Byte]) = {
+ val mapStatusesBytes: Array[Byte] =
+ serializedMapStatus(broadcastManager, isLocal, minBroadcastSize, conf)
+ var mergeStatusesBytes: Array[Byte] = null
+
+ // Fast path: reuse an already-populated merge status cache under the shared read lock.
+ withReadLock {
+ if (cachedSerializedMergeStatus != null) {
+ mergeStatusesBytes = cachedSerializedMergeStatus
+ }
+ }
+
+ // Slow path: take the exclusive write lock and re-check the cache before serializing, so
+ // that only one thread performs the serialization.
+ if (mergeStatusesBytes == null) withWriteLock {
+ if (cachedSerializedMergeStatus == null) {
+ val serResult = MapOutputTracker.serializeOutputStatuses[MergeStatus](
+ mergeStatuses, broadcastManager, isLocal, minBroadcastSize, conf)
+ cachedSerializedMergeStatus = serResult._1
+ cachedSerializedBroadcastMergeStatus = serResult._2
+ }
+
+ // The following line has to be outside if statement since it's possible that another
+ // thread initializes cachedSerializedMergeStatus in-between `withReadLock` and
+ // `withWriteLock`.
+ mergeStatusesBytes = cachedSerializedMergeStatus
+ }
+ (mapStatusesBytes, mergeStatusesBytes)
+ }
+
// Used in testing.
def hasCachedSerializedBroadcast: Boolean = withReadLock {
cachedSerializedBroadcast != null
@@ -254,6 +361,10 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
f(mapStatuses)
}
+  /** Applies `f` to the merge statuses array while holding the read lock. */
+  def withMergeStatuses[T](f: Array[MergeStatus] => T): T = withReadLock(f(mergeStatuses))
+
/**
* Clears the cached serialized map output statuses.
*/
@@ -269,14 +380,35 @@ private class ShuffleStatus(numPartitions: Int) extends Logging {
}
cachedSerializedMapStatus = null
}
+
+ /**
+ * Clears the cached serialized merge result statuses.
+ */
+  def invalidateSerializedMergeOutputStatusCache(): Unit = withWriteLock {
+    val broadcast = cachedSerializedBroadcastMergeStatus
+    if (broadcast != null) {
+      Utils.tryLogNonFatalError {
+        // The no-argument destroy() is relied on to be non-blocking (blocking = false) so this
+        // doesn't hang while trying to send cleanup RPCs to dead executors.
+        broadcast.destroy()
+      }
+      cachedSerializedBroadcastMergeStatus = null
+    }
+    // Drop the locally cached bytes whether or not a broadcast existed.
+    cachedSerializedMergeStatus = null
+  }
}
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int)
extends MapOutputTrackerMessage
+/** Request for both the map statuses and the merge statuses of a shuffle (push-based shuffle). */
+private[spark] case class GetMapAndMergeResultStatuses(shuffleId: Int)
+ extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
-private[spark] case class GetMapOutputMessage(shuffleId: Int, context: RpcCallContext)
+/** Requests queued on the driver-side tracker and handled by its message-loop threads. */
+private[spark] sealed trait MapOutputTrackerMasterMessage
+private[spark] case class GetMapOutputMessage(
+    shuffleId: Int,
+    context: RpcCallContext) extends MapOutputTrackerMasterMessage
+private[spark] case class GetMapAndMergeOutputMessage(
+    shuffleId: Int,
+    context: RpcCallContext) extends MapOutputTrackerMasterMessage
/** RpcEndpoint class for MapOutputTrackerMaster */
private[spark] class MapOutputTrackerMasterEndpoint(
@@ -288,8 +420,13 @@ private[spark] class MapOutputTrackerMasterEndpoint(
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case GetMapOutputStatuses(shuffleId: Int) =>
val hostPort = context.senderAddress.hostPort
- logInfo(s"Asked to send map output locations for shuffle ${shuffleId} to ${hostPort}")
- tracker.post(new GetMapOutputMessage(shuffleId, context))
+ logInfo(s"Asked to send map output locations for shuffle $shuffleId to $hostPort")
+ tracker.post(GetMapOutputMessage(shuffleId, context))
+
+ case GetMapAndMergeResultStatuses(shuffleId: Int) =>
+ val hostPort = context.senderAddress.hostPort
+ logInfo(s"Asked to send map/merge result locations for shuffle $shuffleId to $hostPort")
+ tracker.post(GetMapAndMergeOutputMessage(shuffleId, context))
case StopMapOutputTracker =>
logInfo("MapOutputTrackerMasterEndpoint stopped!")
@@ -368,6 +505,40 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
/**
+ * Called from executors upon fetch failure on an entire merged shuffle reduce partition.
+ * Such failures can happen if the shuffle client fails to fetch the metadata for the given
+ * merged shuffle partition. This method gets the server URIs and output sizes for each
+ * shuffle block that is merged in the specified merged shuffle block so fetch failure on a
+ * merged shuffle block can fall back to fetching the unmerged blocks.
+ *
+ * @param shuffleId the shuffle whose merged partition failed to be fetched.
+ * @param partitionId the reduce partition id of the merged shuffle partition.
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
+ * tuples describing the shuffle blocks that are stored at that block manager.
+ */
+ def getMapSizesForMergeResult(
+ shuffleId: Int,
+ partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
+ /**
+ * Called from executors upon fetch failure on a merged shuffle reduce partition chunk. This is
+ * to get the server URIs and output sizes for each shuffle block that is merged in the specified
+ * merged shuffle partition chunk so fetch failure on a merged shuffle block chunk can fall back
+ * to fetching the unmerged blocks.
+ *
+ * chunkBitmap tracks the mapIds which are part of the current merged chunk; this way, if there
+ * is a fetch failure on the merged chunk, it can fall back to fetching the corresponding
+ * original blocks that are part of this merged chunk.
+ *
+ * @param shuffleId the shuffle whose merged chunk failed to be fetched.
+ * @param partitionId the reduce partition id the merged chunk belongs to.
+ * @param chunkBitmap the mapIds merged into the failed chunk.
+ * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
+ * and the second item is a sequence of (shuffle block ID, shuffle block size, map index)
+ * tuples describing the shuffle blocks that are stored at that block manager.
+ */
+ def getMapSizesForMergeResult(
+ shuffleId: Int,
+ partitionId: Int,
+ chunkBitmap: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+
+ /**
* Deletes map output status information for the specified shuffle stage.
*/
def unregisterShuffle(shuffleId: Int): Unit
@@ -415,8 +586,11 @@ private[spark] class MapOutputTrackerMaster(
private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
- // requests for map output statuses
- private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
+ // Queue of MapOutputTrackerMasterMessages (map-only or map/merge status requests) posted by
+ // the RPC endpoint via `post` and drained by the MessageLoop threads.
+ private val mapOutputTrackerMasterMessages =
+ new LinkedBlockingQueue[MapOutputTrackerMasterMessage]
+
+ // Whether push-based shuffle is enabled; decides whether registered shuffles track merge
+ // results and whether merged-block locations are considered as preferred locations.
+ private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf)
// Thread pool used for handling map output status requests. This is a separate thread pool
// to ensure we don't block the normal dispatcher threads.
@@ -439,31 +613,47 @@ private[spark] class MapOutputTrackerMaster(
throw new IllegalArgumentException(msg)
}
- def post(message: GetMapOutputMessage): Unit = {
- mapOutputRequests.offer(message)
+ def post(message: MapOutputTrackerMasterMessage): Unit = {
+ mapOutputTrackerMasterMessages.offer(message)
}
/** Message loop used for dispatching messages. */
private class MessageLoop extends Runnable {
+    /**
+     * Replies to a status request for `shuffleId`: with the serialized map statuses, or with a
+     * (serialized map statuses, serialized merge statuses) pair when `needMergeOutput` is set.
+     *
+     * @param shuffleId the shuffle whose statuses are requested.
+     * @param context the RPC context used to send the reply.
+     * @param needMergeOutput whether the requester also wants merge result statuses.
+     */
+    private def handleStatusMessage(
+        shuffleId: Int,
+        context: RpcCallContext,
+        needMergeOutput: Boolean): Unit = {
+      val hostPort = context.senderAddress.hostPort
+      val shuffleStatus = shuffleStatuses.get(shuffleId).head
+      // The label must describe the reply actually sent below (it was previously inverted).
+      val outputKind = if (needMergeOutput) "map/merge" else "map"
+      logDebug(s"Handling request to send $outputKind" +
+        s" output locations for shuffle $shuffleId to $hostPort")
+      if (needMergeOutput) {
+        context.reply(
+          shuffleStatus.serializedMapAndMergeStatus(
+            broadcastManager, isLocal, minSizeForBroadcast, conf))
+      } else {
+        context.reply(
+          shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast, conf))
+      }
+    }
+
override def run(): Unit = {
try {
while (true) {
try {
- val data = mapOutputRequests.take()
- if (data == PoisonPill) {
+ val data = mapOutputTrackerMasterMessages.take()
+ if (data == PoisonPill) {
// Put PoisonPill back so that other MessageLoops can see it.
- mapOutputRequests.offer(PoisonPill)
+ mapOutputTrackerMasterMessages.offer(PoisonPill)
return
}
- val context = data.context
- val shuffleId = data.shuffleId
- val hostPort = context.senderAddress.hostPort
- logDebug("Handling request to send map output locations for shuffle " + shuffleId +
- " to " + hostPort)
- val shuffleStatus = shuffleStatuses.get(shuffleId).head
- context.reply(
- shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast,
- conf))
+
+ data match {
+ case GetMapOutputMessage(shuffleId, context) =>
+ handleStatusMessage(shuffleId, context, false)
+ case GetMapAndMergeOutputMessage(shuffleId, context) =>
+ handleStatusMessage(shuffleId, context, true)
+ }
} catch {
case NonFatal(e) => logError(e.getMessage, e)
}
@@ -475,16 +665,22 @@ private[spark] class MapOutputTrackerMaster(
}
/** A poison endpoint that indicates MessageLoop should exit its message loop. */
- private val PoisonPill = new GetMapOutputMessage(-99, null)
+ private val PoisonPill = GetMapOutputMessage(-99, null)
// Used only in unit tests.
private[spark] def getNumCachedSerializedBroadcast: Int = {
shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
}
- def registerShuffle(shuffleId: Int, numMaps: Int): Unit = {
- if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
- throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+  /**
+   * Registers a new shuffle with `numMaps` map outputs. When push-based shuffle is enabled the
+   * status also reserves one merge-result slot per reducer (`numReduces`).
+   *
+   * @throws IllegalArgumentException if `shuffleId` is already registered.
+   */
+  def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = {
+    // Without push-based shuffle, use ShuffleStatus's default reducer count, which allocates
+    // no merge-status array.
+    val shuffleStatus =
+      if (pushBasedShuffleEnabled) new ShuffleStatus(numMaps, numReduces)
+      else new ShuffleStatus(numMaps)
+    if (shuffleStatuses.put(shuffleId, shuffleStatus).isDefined) {
+      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
+    }
}
@@ -524,10 +720,49 @@ private[spark] class MapOutputTrackerMaster(
}
}
+  /** Records `status` as the merge result of partition `reduceId` of shuffle `shuffleId`. */
+  def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = {
+    shuffleStatuses(shuffleId).addMergeResult(reduceId, status)
+  }
+
+  /** Records a batch of (reduceId, MergeStatus) merge results for shuffle `shuffleId`. */
+  def registerMergeResults(shuffleId: Int, statuses: Seq[(Int, MergeStatus)]): Unit =
+    statuses.foreach { case (reduceId, status) =>
+      registerMergeResult(shuffleId, reduceId, status)
+    }
+
+  /**
+   * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId
+   * is specified, it will only unregister the merge result if the mapId is part of that merge
+   * result.
+   *
+   * @param shuffleId the shuffleId.
+   * @param reduceId the reduceId.
+   * @param bmAddress block manager address.
+   * @param mapId the optional mapId which should be checked to see it was part of the merge
+   *              result.
+   * @throws SparkException if `shuffleId` is not registered.
+   */
+  def unregisterMergeResult(
+      shuffleId: Int,
+      reduceId: Int,
+      bmAddress: BlockManagerId,
+      mapId: Option[Int] = None): Unit = {
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        val mergeStatus = shuffleStatus.mergeStatuses(reduceId)
+        // With no mapId every merge result qualifies; otherwise the map must be part of it.
+        if (mergeStatus != null && mapId.forall(id => mergeStatus.tracker.contains(id))) {
+          shuffleStatus.removeMergeResult(reduceId, bmAddress)
+          // NOTE(review): the epoch is bumped even when removeMergeResult drops nothing because
+          // the status is served from a different block manager than `bmAddress`.
+          incrementEpoch()
+        }
+      case None =>
+        throw new SparkException("unregisterMergeResult called for nonexistent shuffle ID")
+    }
+  }
+
/** Unregister shuffle data */
def unregisterShuffle(shuffleId: Int): Unit = {
shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
shuffleStatus.invalidateSerializedMapOutputStatusCache()
+ shuffleStatus.invalidateSerializedMergeOutputStatusCache()
}
}
@@ -554,7 +789,12 @@ private[spark] class MapOutputTrackerMaster(
def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
def getNumAvailableOutputs(shuffleId: Int): Int = {
- shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+ shuffleStatuses.get(shuffleId).map(_.numAvailableMapOutputs).getOrElse(0)
+ }
+
+ /** VisibleForTest. Invoked in test only. */
+ private[spark] def getNumAvailableMergeResults(shuffleId: Int): Int = {
+ shuffleStatuses.get(shuffleId).map(_.numAvailableMergeResults).getOrElse(0)
}
/**
@@ -633,7 +873,9 @@ private[spark] class MapOutputTrackerMaster(
/**
* Return the preferred hosts on which to run the given map output partition in a given shuffle,
- * i.e. the nodes that the most outputs for that partition are on.
+ * i.e. the nodes that the most outputs for that partition are on. If the map output is
+ * pre-merged, then return the node where the merged block is located if the merge ratio is
+ * above the threshold.
*
* @param dep shuffle dependency object
* @param partitionId map output partition that we want to read
@@ -641,15 +883,40 @@ private[spark] class MapOutputTrackerMaster(
*/
def getPreferredLocationsForShuffle(dep: ShuffleDependency[_, _, _], partitionId: Int)
: Seq[String] = {
- if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
- dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
- val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
- dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
- if (blockManagerIds.nonEmpty) {
- blockManagerIds.get.map(_.host)
+ val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull
+ if (shuffleStatus != null) {
+ // Check if the map output is pre-merged and if the merge ratio is above the threshold.
+ // If so, the location of the merged block is the preferred location.
+ val preferredLoc = if (pushBasedShuffleEnabled) {
+ shuffleStatus.withMergeStatuses { statuses =>
+ val status = statuses(partitionId)
+ val numMaps = dep.rdd.partitions.length
+ if (status != null && status.getNumMissingMapOutputs(numMaps).toDouble / numMaps
+ <= (1 - REDUCER_PREF_LOCS_FRACTION)) {
+ Seq(status.location.host)
+ } else {
+ Nil
+ }
+ }
} else {
Nil
}
+ if (preferredLoc.nonEmpty) {
+ preferredLoc
+ } else {
+ if (shuffleLocalityEnabled && dep.rdd.partitions.length < SHUFFLE_PREF_MAP_THRESHOLD &&
+ dep.partitioner.numPartitions < SHUFFLE_PREF_REDUCE_THRESHOLD) {
+ val blockManagerIds = getLocationsWithLargestOutputs(dep.shuffleId, partitionId,
+ dep.partitioner.numPartitions, REDUCER_PREF_LOCS_FRACTION)
+ if (blockManagerIds.nonEmpty) {
+ blockManagerIds.get.map(_.host)
+ } else {
+ Nil
+ }
+ } else {
+ Nil
+ }
+ }
} else {
Nil
}
@@ -774,8 +1041,25 @@ private[spark] class MapOutputTrackerMaster(
}
}
+  // This method is only called in local-mode. Since push based shuffle won't be
+  // enabled in local-mode, this method returns an empty iterator.
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    Iterator.empty
+  }
+
+  // This method is only called in local-mode. Since push based shuffle won't be
+  // enabled in local-mode, this method returns an empty iterator.
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int,
+      chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    Iterator.empty
+  }
+
override def stop(): Unit = {
- mapOutputRequests.offer(PoisonPill)
+ mapOutputTrackerMasterMessages.offer(PoisonPill)
threadpool.shutdown()
try {
sendTracker(StopMapOutputTracker)
@@ -799,6 +1083,11 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
val mapStatuses: Map[Int, Array[MapStatus]] =
new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+ val mergeStatuses: Map[Int, Array[MergeStatus]] =
+ new ConcurrentHashMap[Int, Array[MergeStatus]]().asScala
+
+ private val fetchMergeResult = Utils.isPushBasedShuffleEnabled(conf)
+
/**
* A [[KeyLock]] whose key is a shuffle id to ensure there is only one thread fetching
* the same shuffle block.
@@ -812,61 +1101,150 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
logDebug(s"Fetching outputs for shuffle $shuffleId")
- val statuses = getStatuses(shuffleId, conf)
+ val (mapOutputStatuses, mergedOutputStatuses) = getStatuses(shuffleId, conf)
try {
- val actualEndMapIndex = if (endMapIndex == Int.MaxValue) statuses.length else endMapIndex
+ val actualEndMapIndex =
+ if (endMapIndex == Int.MaxValue) mapOutputStatuses.length else endMapIndex
logDebug(s"Convert map statuses for shuffle $shuffleId, " +
s"mappers $startMapIndex-$actualEndMapIndex, partitions $startPartition-$endPartition")
MapOutputTracker.convertMapStatuses(
- shuffleId, startPartition, endPartition, statuses, startMapIndex, actualEndMapIndex)
+ shuffleId, startPartition, endPartition, mapOutputStatuses, startMapIndex,
+ actualEndMapIndex, Option(mergedOutputStatuses))
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
mapStatuses.clear()
+ mergeStatuses.clear()
+ throw e
+ }
+ }
+
+  override def getMapSizesForMergeResult(
+      shuffleId: Int,
+      partitionId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+    // Re-fetch both status arrays: another task in this executor may have cleared the caches.
+    val (mapStatusArray, mergeStatusArray) = getStatuses(shuffleId, conf)
+    try {
+      val partitionMergeStatus = mergeStatusArray(partitionId)
+      // Without the original MergeStatus the list of unmerged blocks cannot be identified, so
+      // validation throws MetadataFetchFailedException in that case.
+      MapOutputTracker.validateStatus(partitionMergeStatus, shuffleId, partitionId)
+      // Partition-level fallback: use the MergeStatus's partition-level bitmap.
+      MapOutputTracker.getMapStatusesForMergeStatus(
+        shuffleId, partitionId, mapStatusArray, partitionMergeStatus.tracker)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // A fetch failure means the cached statuses are outdated; drop them before rethrowing.
+        mapStatuses.clear()
+        mergeStatuses.clear()
+        throw e
+    }
+  }
+
+ // Chunk-level fallback: unlike the partition-level overload, the caller supplies the bitmap
+ // (chunkTracker) of mapIds merged into the failed chunk, so no MergeStatus lookup is needed
+ // and only the map statuses are used.
+ override def getMapSizesForMergeResult(
+ shuffleId: Int,
+ partitionId: Int,
+ chunkTracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+ logDebug(s"Fetching backup outputs for shuffle $shuffleId, partition $partitionId")
+ // Fetch the map statuses and merge statuses again since they might have already been
+ // cleared by another task running in the same executor.
+ val (mapOutputStatuses, _) = getStatuses(shuffleId, conf)
+ try {
+ MapOutputTracker.getMapStatusesForMergeStatus(shuffleId, partitionId, mapOutputStatuses,
+ chunkTracker)
+ } catch {
+ // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+ case e: MetadataFetchFailedException =>
+ mapStatuses.clear()
+ mergeStatuses.clear()
throw e
}
}
/**
- * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
+ * Get or fetch the array of MapStatuses and MergeStatuses if push based shuffle enabled
+ * for a given shuffle ID. NOTE: clients MUST synchronize
* on this array when reading it, because on the driver, we may be changing it in place.
*
* (It would be nice to remove this restriction in the future.)
*/
- private def getStatuses(shuffleId: Int, conf: SparkConf): Array[MapStatus] = {
- val statuses = mapStatuses.get(shuffleId).orNull
- if (statuses == null) {
- logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
- val startTimeNs = System.nanoTime()
- fetchingLock.withLock(shuffleId) {
- var fetchedStatuses = mapStatuses.get(shuffleId).orNull
- if (fetchedStatuses == null) {
- logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
- val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
- try {
- fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
- } catch {
- case e: SparkException =>
- throw new MetadataFetchFailedException(shuffleId, -1,
- s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
- e.getCause)
+ private def getStatuses(
+ shuffleId: Int,
+ conf: SparkConf): (Array[MapStatus], Array[MergeStatus]) = {
+ if (fetchMergeResult) {
+ val mapOutputStatuses = mapStatuses.get(shuffleId).orNull
+ val mergeOutputStatuses = mergeStatuses.get(shuffleId).orNull
+
+ if (mapOutputStatuses == null || mergeOutputStatuses == null) {
+ logInfo("Don't have map/merge outputs for shuffle " + shuffleId + ", fetching them")
+ val startTimeNs = System.nanoTime()
+ fetchingLock.withLock(shuffleId) {
+ var fetchedMapStatuses = mapStatuses.get(shuffleId).orNull
+ var fetchedMergeStatuses = mergeStatuses.get(shuffleId).orNull
+ if (fetchedMapStatuses == null || fetchedMergeStatuses == null) {
+ logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+ val fetchedBytes =
+ askTracker[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(shuffleId))
+ try {
+ fetchedMapStatuses =
+ MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, conf)
+ fetchedMergeStatuses =
+ MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, conf)
+ } catch {
+ case e: SparkException =>
+ throw new MetadataFetchFailedException(shuffleId, -1,
+ s"Unable to deserialize broadcasted map/merge statuses" +
+ s" for shuffle $shuffleId: " + e.getCause)
+ }
+ logInfo("Got the map/merge output locations")
+ mapStatuses.put(shuffleId, fetchedMapStatuses)
+ mergeStatuses.put(shuffleId, fetchedMergeStatuses)
}
- logInfo("Got the output locations")
- mapStatuses.put(shuffleId, fetchedStatuses)
+ logDebug(s"Fetching map/merge output statuses for shuffle $shuffleId took " +
+ s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+ (fetchedMapStatuses, fetchedMergeStatuses)
}
- logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
- s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
- fetchedStatuses
+ } else {
+ (mapOutputStatuses, mergeOutputStatuses)
}
} else {
- statuses
+ val statuses = mapStatuses.get(shuffleId).orNull
+ if (statuses == null) {
+ logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+ val startTimeNs = System.nanoTime()
+ fetchingLock.withLock(shuffleId) {
+ var fetchedStatuses = mapStatuses.get(shuffleId).orNull
+ if (fetchedStatuses == null) {
+ logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+ val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
+ try {
+ fetchedStatuses =
+ MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf)
+ } catch {
+ case e: SparkException =>
+ throw new MetadataFetchFailedException(shuffleId, -1,
+ s"Unable to deserialize broadcasted map statuses for shuffle $shuffleId: " +
+ e.getCause)
+ }
+ logInfo("Got the map output locations")
+ mapStatuses.put(shuffleId, fetchedStatuses)
+ }
+ logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+ s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+ (fetchedStatuses, null)
+ }
+ } else {
+ (statuses, null)
+ }
}
}
-
/** Unregister shuffle data. */
def unregisterShuffle(shuffleId: Int): Unit = {
mapStatuses.remove(shuffleId)
+ mergeStatuses.remove(shuffleId) // merge statuses are cached per shuffle too; drop them
}
/**
@@ -880,6 +1258,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
epoch = newEpoch
mapStatuses.clear()
+ mergeStatuses.clear()
}
}
}
@@ -891,11 +1270,13 @@ private[spark] object MapOutputTracker extends Logging {
private val DIRECT = 0
private val BROADCAST = 1
- // Serialize an array of map output locations into an efficient byte format so that we can send
- // it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will
- // generally be pretty compressible because many map outputs will be on the same hostname.
- def serializeMapStatuses(
- statuses: Array[MapStatus],
+ private val SHUFFLE_PUSH_MAP_ID = -1
+
+ // Serialize an array of map/merge output locations into an efficient byte format so that we can
+ // send it to reduce tasks. We do this by compressing the serialized bytes using Zstd. They will
+ // generally be pretty compressible because many outputs will be on the same hostname.
+ def serializeOutputStatuses[T <: ShuffleOutputStatus](
+ statuses: Array[T],
broadcastManager: BroadcastManager,
isLocal: Boolean,
minBroadcastSize: Int,
@@ -931,15 +1312,16 @@ private[spark] object MapOutputTracker extends Logging {
oos.close()
}
val outArr = out.toByteArray
- logInfo("Broadcast mapstatuses size = " + outArr.length + ", actual size = " + arr.length)
+ logInfo("Broadcast outputstatuses size = " + outArr.length + ", actual size = " + arr.length)
(outArr, bcast)
} else {
(arr, null)
}
}
- // Opposite of serializeMapStatuses.
- def deserializeMapStatuses(bytes: Array[Byte], conf: SparkConf): Array[MapStatus] = {
+ // Opposite of serializeOutputStatuses.
+ def deserializeOutputStatuses[T <: ShuffleOutputStatus](
+ bytes: Array[Byte], conf: SparkConf): Array[T] = {
assert (bytes.length > 0)
def deserializeObject(arr: Array[Byte], off: Int, len: Int): AnyRef = {
@@ -958,20 +1340,22 @@ private[spark] object MapOutputTracker extends Logging {
bytes(0) match {
case DIRECT =>
- deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[MapStatus]]
+ deserializeObject(bytes, 1, bytes.length - 1).asInstanceOf[Array[T]]
case BROADCAST =>
try {
// deserialize the Broadcast, pull .value array out of it, and then deserialize that
val bcast = deserializeObject(bytes, 1, bytes.length - 1).
asInstanceOf[Broadcast[Array[Byte]]]
- logInfo("Broadcast mapstatuses size = " + bytes.length +
+ logInfo("Broadcast outputstatuses size = " + bytes.length +
", actual size = " + bcast.value.length)
// Important - ignore the DIRECT tag ! Start from offset 1
- deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[MapStatus]]
+ deserializeObject(bcast.value, 1, bcast.value.length - 1).asInstanceOf[Array[T]]
} catch {
case e: IOException =>
- logWarning("Exception encountered during deserializing broadcasted map statuses: ", e)
- throw new SparkException("Unable to deserialize broadcasted map statuses", e)
+ logWarning("Exception encountered during deserializing broadcasted" +
+ " output statuses: ", e)
+ throw new SparkException("Unable to deserialize broadcasted" +
+ " output statuses", e)
}
case _ => throw new IllegalArgumentException("Unexpected byte tag = " + bytes(0))
}
@@ -983,15 +1367,19 @@ private[spark] object MapOutputTracker extends Logging {
* stored at that block manager.
* Note that empty blocks are filtered in the result.
*
+ * If push-based shuffle is enabled and an array of merge statuses is available, prioritize
+ * the locations of the merged shuffle partitions over unmerged shuffle blocks.
+ *
* If any of the statuses is null (indicating a missing location due to a failed mapper),
* throws a FetchFailedException.
*
* @param shuffleId Identifier for the shuffle
* @param startPartition Start of map output partition ID range (included in range)
* @param endPartition End of map output partition ID range (excluded from range)
- * @param statuses List of map statuses, indexed by map partition index.
+ * @param mapStatuses List of map statuses, indexed by map partition index.
* @param startMapIndex Start Map index.
* @param endMapIndex End Map index.
+ * @param mergeStatuses List of merge statuses, index by reduce ID.
* @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
* and the second item is a sequence of (shuffle block id, shuffle block size, map index)
* tuples describing the shuffle blocks that are stored at that block manager.
@@ -1000,18 +1388,57 @@ private[spark] object MapOutputTracker extends Logging {
shuffleId: Int,
startPartition: Int,
endPartition: Int,
- statuses: Array[MapStatus],
+ mapStatuses: Array[MapStatus],
startMapIndex : Int,
- endMapIndex: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
- assert (statuses != null)
+ endMapIndex: Int,
+ mergeStatuses: Option[Array[MergeStatus]] = None):
+ Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+ assert (mapStatuses != null)
val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
- val iter = statuses.iterator.zipWithIndex
- for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
- if (status == null) {
- val errorMessage = s"Missing an output location for shuffle $shuffleId"
- logError(errorMessage)
- throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
- } else {
+ // Only use MergeStatus for reduce tasks that fetch all map outputs. Since a merged shuffle
+ // partition consists of blocks merged in random order, we are unable to serve map index
+ // subrange requests. However, when a reduce task needs to fetch blocks from a subrange of
+ // map outputs, it usually indicates skewed partitions which push-based shuffle delegates
+ // to AQE to handle.
+ // TODO: SPARK-35036: Instead of reading map blocks in case of AQE with Push based shuffle,
+ // TODO: improve push based shuffle to read partial merged blocks satisfying the start/end
+ // TODO: map indexes
+ if (mergeStatuses.exists(_.nonEmpty) && startMapIndex == 0
+ && endMapIndex == mapStatuses.length) {
+ // We have MergeStatus and full range of mapIds are requested so return a merged block.
+ val numMaps = mapStatuses.length
+ mergeStatuses.get.zipWithIndex.slice(startPartition, endPartition).foreach {
+ case (mergeStatus, partId) =>
+ val remainingMapStatuses = if (mergeStatus != null && mergeStatus.totalSize > 0) {
+ // If MergeStatus is available for the given partition, add location of the
+ // pre-merged shuffle partition for this partition ID. Here we create a
+ // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is
+ // a merged shuffle block.
+ splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1))
+ // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper
+ // shuffle partition blocks, fetch the original map produced shuffle partition blocks
+ val mapStatusesWithIndex = mapStatuses.zipWithIndex
+ mergeStatus.getMissingMaps(numMaps).map(mapStatusesWithIndex)
+ } else {
+ // If MergeStatus is not available for the given partition, fall back to
+ // fetching all the original mapper shuffle partition blocks
+ mapStatuses.zipWithIndex.toSeq
+ }
+ // Add location for the mapper shuffle partition blocks
+ for ((mapStatus, mapIndex) <- remainingMapStatuses) {
+ validateStatus(mapStatus, shuffleId, partId)
+ val size = mapStatus.getSizeForBlock(partId)
+ if (size != 0) {
+ splitsByAddress.getOrElseUpdate(mapStatus.location, ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, mapStatus.mapId, partId), size, mapIndex))
+ }
+ }
+ }
+ } else {
+ val iter = mapStatuses.iterator.zipWithIndex
+ for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
+ validateStatus(status, shuffleId, startPartition)
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
if (size != 0) {
@@ -1024,4 +1451,47 @@ private[spark] object MapOutputTracker extends Logging {
splitsByAddress.mapValues(_.toSeq).iterator
}
+
+ /**
+ * Given a shuffle ID, a partition ID, an array of map statuses, and a bitmap corresponding
+ * to either a merged shuffle partition or a merged shuffle partition chunk, identify
+ * the metadata about the shuffle partition blocks that are merged into the merged shuffle
+ * partition or partition chunk represented by the bitmap.
+ *
+ * @param shuffleId Identifier for the shuffle
+ * @param partitionId The partition ID of the MergeStatus for which we look for the metadata
+ * of the merged shuffle partition blocks
+ * @param mapStatuses List of map statuses, indexed by map index (array position, not mapId)
+ * @param tracker bitmap containing the map indexes merged into the block or block chunk
+ * @return An iterator of 2-item tuples, where the first item is a BlockManagerId, and the
+ * second item is a sequence of (shuffle block ID, shuffle block size, map index) tuples
+ * describing the original map output blocks stored at that block manager.
+ * @throws MetadataFetchFailedException if a merged map index has a null map status
+ */
+ def getMapStatusesForMergeStatus(
+ shuffleId: Int,
+ partitionId: Int,
+ mapStatuses: Array[MapStatus],
+ tracker: RoaringBitmap): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+ assert (mapStatuses != null && tracker != null)
+ val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+ for ((status, mapIndex) <- mapStatuses.zipWithIndex) {
+ // Only add blocks that are merged
+ if (tracker.contains(mapIndex)) {
+ MapOutputTracker.validateStatus(status, shuffleId, partitionId)
+ splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
+ ((ShuffleBlockId(shuffleId, status.mapId, partitionId),
+ status.getSizeForBlock(partitionId), mapIndex))
+ }
+ }
+ splitsByAddress.mapValues(_.toSeq).iterator
+ }
+
+ def validateStatus(status: ShuffleOutputStatus, shuffleId: Int, partition: Int): Unit = {
+ if (status eq null) { // a null entry means the output location is missing
+ val msg = s"Missing an output location for shuffle $shuffleId partition $partition"
+ logError(msg)
+ throw new MetadataFetchFailedException(shuffleId, partition, msg)
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index c2e7c4d..a92d9fa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -456,7 +456,8 @@ private[spark] class DAGScheduler(
// since we can't do it in the RDD constructor because # of partitions is unknown
logInfo(s"Registering RDD ${rdd.id} (${rdd.getCreationSite}) as input to " +
s"shuffle ${shuffleDep.shuffleId}")
- mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length)
+ mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.length,
+ shuffleDep.partitioner.numPartitions)
}
stage
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 1239c32..07eed76 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -29,11 +29,17 @@ import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.Utils
/**
+ * Marker trait (it declares no members) shared by [[MapStatus]] and [[MergeStatus]], so that
+ * MapOutputTracker can serialize, deserialize and validate both with the same generic code.
+ */
+private[spark] trait ShuffleOutputStatus
+
+/**
* Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
* task has shuffle files stored on as well as the sizes of outputs for each reducer, for passing
* on to the reduce tasks.
*/
-private[spark] sealed trait MapStatus {
+private[spark] sealed trait MapStatus extends ShuffleOutputStatus {
/** Location where this task output is. */
def location: BlockManagerId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
new file mode 100644
index 0000000..77d8f8e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/MergeStatus.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.io.{Externalizable, ObjectInput, ObjectOutput}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.network.shuffle.protocol.MergeStatuses
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
+
+/**
+ * The status for the result of merging shuffle partition blocks per individual shuffle partition
+ * maintained by the scheduler. The scheduler would separate the
+ * [[org.apache.spark.network.shuffle.protocol.MergeStatuses]] received from
+ * ExternalShuffleService into individual [[MergeStatus]] which is maintained inside
+ * MapOutputTracker to be served to the reducers when they start fetching shuffle partition
+ * blocks. Note that the reducers are ultimately fetching individual chunks inside a merged
+ * shuffle file, as explained in [[org.apache.spark.network.shuffle.RemoteBlockPushResolver]].
+ * Between the scheduler maintained MergeStatus and the shuffle service maintained per shuffle
+ * partition meta file, we are effectively dividing the metadata for a push-based shuffle into
+ * 2 layers. The scheduler would track the top-level metadata at the shuffle partition level
+ * with MergeStatus, and the shuffle service would maintain the partition level metadata about
+ * how to further divide a merged shuffle partition into multiple chunks with the per-partition
+ * meta file. This helps to reduce the amount of data the scheduler needs to maintain for
+ * push-based shuffle.
+ */
+private[spark] class MergeStatus(
+ private[this] var loc: BlockManagerId,
+ private[this] var mapTracker: RoaringBitmap,
+ private[this] var size: Long)
+ extends Externalizable with ShuffleOutputStatus {
+
+ protected def this() = this(null, null, -1) // For deserialization only
+
+ def location: BlockManagerId = loc
+
+ def totalSize: Long = size
+
+ def tracker: RoaringBitmap = mapTracker
+
+ /**
+ * Get the map indexes in [0, numMaps) whose shuffle partition blocks were not merged.
+ * These are positions in the shuffle's MapStatus array, not MapStatus.mapId values.
+ * The reducer uses them to decide which blocks to fetch in the original way.
+ */
+ def getMissingMaps(numMaps: Int): Seq[Int] = {
+ (0 until numMaps).filter(i => !mapTracker.contains(i))
+ }
+
+ /**
+ * Count of map indexes in [0, numMaps) that were not merged (getMissingMaps(numMaps).size).
+ */
+ def getNumMissingMapOutputs(numMaps: Int): Int = {
+ (0 until numMaps).count(i => !mapTracker.contains(i))
+ }
+
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
+ loc.writeExternal(out)
+ mapTracker.writeExternal(out)
+ out.writeLong(size)
+ }
+
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
+ loc = BlockManagerId(in)
+ mapTracker = new RoaringBitmap()
+ mapTracker.readExternal(in)
+ size = in.readLong()
+ }
+}
+
+private[spark] object MergeStatus {
+ // Placeholder reducer count for tests in which push-based shuffle is disabled
+ val SHUFFLE_PUSH_DUMMY_NUM_REDUCES = 1
+
+ /**
+ * Split one MergeStatuses message reported by an ExternalShuffleService into one
+ * (reduceId, MergeStatus) pair per entry. The caller supplies the service's location; every
+ * resulting MergeStatus points at a merger BlockManagerId with that same host and port.
+ */
+ def convertMergeStatusesToMergeStatusArr(
+ mergeStatuses: MergeStatuses,
+ loc: BlockManagerId): Seq[(Int, MergeStatus)] = {
+ assert(mergeStatuses.reduceIds.length == mergeStatuses.bitmaps.length &&
+ mergeStatuses.sizes.length == mergeStatuses.bitmaps.length)
+ val mergerLoc = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, loc.host, loc.port)
+ Array.tabulate(mergeStatuses.bitmaps.length) { i =>
+ val status = MergeStatus(mergerLoc, mergeStatuses.bitmaps(i), mergeStatuses.sizes(i))
+ (mergeStatuses.reduceIds(i), status)
+ }
+ }
+
+ /** Builds a [[MergeStatus]] from a merger location, merged-map bitmap and total size. */
+ def apply(loc: BlockManagerId, bitmap: RoaringBitmap, size: Long): MergeStatus = {
+ new MergeStatus(loc, bitmap, size)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 83fe450..f4b47e2 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -21,17 +21,19 @@ import scala.collection.mutable.ArrayBuffer
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
+import org.roaringbitmap.RoaringBitmap
import org.apache.spark.LocalSparkContext._
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.internal.config._
import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE}
+import org.apache.spark.internal.config.Tests.IS_TESTING
import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
-import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus}
+import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus, MergeStatus}
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
-class MapOutputTrackerSuite extends SparkFunSuite {
+class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
private val conf = new SparkConf
private def newTrackerMaster(sparkConf: SparkConf = conf) = {
@@ -58,7 +60,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
- tracker.registerShuffle(10, 2)
+ tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
assert(tracker.containsShuffle(10))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
@@ -82,7 +84,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
- tracker.registerShuffle(10, 2)
+ tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
@@ -105,7 +107,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
- tracker.registerShuffle(10, 2)
+ tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
val compressedSize1000 = MapStatus.compressSize(1000L)
val compressedSize10000 = MapStatus.compressSize(10000L)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
@@ -140,7 +142,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
mapWorkerTracker.trackerEndpoint =
mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
- masterTracker.registerShuffle(10, 1)
+ masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
mapWorkerTracker.updateEpoch(masterTracker.getEpoch)
// This is expected to fail because no outputs have been registered for the shuffle.
intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) }
@@ -183,7 +185,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
// Message size should be ~123B, and no exception should be thrown
- masterTracker.registerShuffle(10, 1)
+ masterTracker.registerShuffle(10, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0), 5))
val senderAddress = RpcAddress("localhost", 12345)
@@ -217,7 +219,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// on hostA with output size 2
// on hostA with output size 2
// on hostB with output size 3
- tracker.registerShuffle(10, 3)
+ tracker.registerShuffle(10, 3, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(2L), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
@@ -260,7 +262,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// Frame size should be ~1.1MB, and MapOutputTrackerMasterEndpoint should throw exception.
// Note that the size is hand-selected here because map output statuses are compressed before
// being sent.
- masterTracker.registerShuffle(20, 100)
+ masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
@@ -306,7 +308,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
- tracker.registerShuffle(10, 2)
+ tracker.registerShuffle(10, 2, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
@@ -332,6 +334,219 @@ class MapOutputTrackerSuite extends SparkFunSuite {
rpcEnv.shutdown()
}
+ test("SPARK-32921: master register and unregister merge result") {
+ conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) // NOTE(review): mutates the shared suite conf
+ conf.set(IS_TESTING, true)
+ val rpcEnv = createRpcEnv("test")
+ val tracker = newTrackerMaster()
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+ tracker.registerShuffle(10, 4, 2) // 4 map tasks, 2 reduce partitions
+ assert(tracker.containsShuffle(10))
+ val bitmap = new RoaringBitmap()
+ bitmap.add(0)
+ bitmap.add(1)
+ // Register merge results for reduce partitions 0 and 1, then unregister partition 0.
+ tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+ bitmap, 1000L))
+ tracker.registerMergeResult(10, 1, MergeStatus(BlockManagerId("b", "hostB", 1000),
+ bitmap, 1000L))
+ assert(tracker.getNumAvailableMergeResults(10) == 2)
+ tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
+ assert(tracker.getNumAvailableMergeResults(10) == 1)
+ tracker.stop()
+ rpcEnv.shutdown()
+ }
+
+ test("SPARK-32921: get map sizes with merged shuffle") {
+ conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) // NOTE(review): mutates the shared suite conf
+ conf.set(IS_TESTING, true)
+ val hostname = "localhost"
+ val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+ // Driver-side tracker, reachable through an RPC endpoint.
+ val masterTracker = newTrackerMaster()
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+ // Worker-side tracker that fetches statuses from the driver's endpoint.
+ val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+ // Shuffle 10 with 4 map tasks and 1 reduce partition.
+ masterTracker.registerShuffle(10, 4, 1)
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ val bitmap = new RoaringBitmap()
+ bitmap.add(0)
+ bitmap.add(1)
+ bitmap.add(3)
+ // Map 2 is absent from the merge bitmap, i.e. it is a "hole" in the merged partition.
+ val blockMgrId = BlockManagerId("a", "hostA", 1000)
+ masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+ masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+ masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+ masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+ // Expect the merged block (mapId -1) plus the unmerged block of map 2.
+ masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+ bitmap, 3000L))
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq ===
+ Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, -1, 0), 3000, -1),
+ (ShuffleBlockId(10, 2, 0), size1000, 2)))))
+ // Shut down both trackers and both RPC environments.
+ masterTracker.stop()
+ slaveTracker.stop()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
+ }
+
+ test("SPARK-32921: get map statuses from merged shuffle") {
+ conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) // NOTE(review): mutates the shared suite conf
+ conf.set(IS_TESTING, true)
+ val hostname = "localhost"
+ val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+ // Driver-side tracker, reachable through an RPC endpoint.
+ val masterTracker = newTrackerMaster()
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+ // Worker-side tracker that fetches statuses from the driver's endpoint.
+ val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+ // Shuffle 10 with 4 map tasks and 1 reduce partition.
+ masterTracker.registerShuffle(10, 4, 1)
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ // This is expected to fail because no outputs have been registered for the shuffle.
+ intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
+ val bitmap = new RoaringBitmap()
+ bitmap.add(0)
+ bitmap.add(1)
+ bitmap.add(2)
+ bitmap.add(3)
+ // Every map output (maps 0-3) is covered by the merge bitmap.
+ val blockMgrId = BlockManagerId("a", "hostA", 1000)
+ masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+ masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+ masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+ masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+ // The merge result for partition 0 should list all four original map output blocks.
+ masterTracker.registerMergeResult(10, 0, MergeStatus(blockMgrId,
+ bitmap, 4000L))
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ assert(slaveTracker.getMapSizesForMergeResult(10, 0).toSeq ===
+ Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+ (ShuffleBlockId(10, 1, 0), size1000, 1), (ShuffleBlockId(10, 2, 0), size1000, 2),
+ (ShuffleBlockId(10, 3, 0), size1000, 3)))))
+ masterTracker.stop()
+ slaveTracker.stop()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
+ }
+
+ test("SPARK-32921: get map statuses for merged shuffle block chunks") {
+ conf.set(PUSH_BASED_SHUFFLE_ENABLED, true) // NOTE(review): mutates the shared suite conf
+ conf.set(IS_TESTING, true)
+ val hostname = "localhost"
+ val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+ // Driver-side tracker, reachable through an RPC endpoint.
+ val masterTracker = newTrackerMaster()
+ masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+ // Worker-side tracker that fetches statuses from the driver's endpoint.
+ val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf))
+ val slaveTracker = new MapOutputTrackerWorker(conf)
+ slaveTracker.trackerEndpoint =
+ slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+ // Shuffle 10 with 4 map tasks and 1 reduce partition.
+ masterTracker.registerShuffle(10, 4, 1)
+ slaveTracker.updateEpoch(masterTracker.getEpoch)
+ // Map outputs 0-3; no merge result is registered in this test.
+ val blockMgrId = BlockManagerId("a", "hostA", 1000)
+ masterTracker.registerMapOutput(10, 0, MapStatus(blockMgrId, Array(1000L), 0))
+ masterTracker.registerMapOutput(10, 1, MapStatus(blockMgrId, Array(1000L), 1))
+ masterTracker.registerMapOutput(10, 2, MapStatus(blockMgrId, Array(1000L), 2))
+ masterTracker.registerMapOutput(10, 3, MapStatus(blockMgrId, Array(1000L), 3))
+ // A chunk bitmap covering maps 0 and 2: only those two blocks are expected back.
+ val chunkBitmap = new RoaringBitmap()
+ chunkBitmap.add(0)
+ chunkBitmap.add(2)
+ val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
+ assert(slaveTracker.getMapSizesForMergeResult(10, 0, chunkBitmap).toSeq ===
+ Seq((blockMgrId, ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000, 0),
+ (ShuffleBlockId(10, 2, 0), size1000, 2))))
+ )
+ masterTracker.stop()
+ slaveTracker.stop()
+ rpcEnv.shutdown()
+ slaveRpcEnv.shutdown()
+ }
+
+ test("SPARK-32921: getPreferredLocationsForShuffle with MergeStatus") {
+ val rpcEnv = createRpcEnv("test")
+ val tracker = newTrackerMaster()
+ sc = new SparkContext("local", "test", conf.clone())
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+ // Setup 6 map tasks
+ // on hostA with output size 2
+ // on hostA with output size 2
+ // on hostB with output size 3
+ // on hostB with output size 3
+ // on hostC with output size 1
+ // on hostC with output size 1
+ tracker.registerShuffle(10, 6, 1)
+ tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
+ Array(2L), 5))
+ tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("a", "hostA", 1000),
+ Array(2L), 6))
+ tracker.registerMapOutput(10, 2, MapStatus(BlockManagerId("b", "hostB", 1000),
+ Array(3L), 7))
+ tracker.registerMapOutput(10, 3, MapStatus(BlockManagerId("b", "hostB", 1000),
+ Array(3L), 8))
+ tracker.registerMapOutput(10, 4, MapStatus(BlockManagerId("c", "hostC", 1000),
+ Array(1L), 9))
+ tracker.registerMapOutput(10, 5, MapStatus(BlockManagerId("c", "hostC", 1000),
+ Array(1L), 10))
+ // NOTE(review): tracker uses the shared suite conf, which earlier tests mutate; confirm
+ val rdd = sc.parallelize(1 to 6, 6).map(num => (num, num).asInstanceOf[Product2[Int, Int]])
+ val mockShuffleDep = mock(classOf[ShuffleDependency[Int, Int, _]])
+ when(mockShuffleDep.shuffleId).thenReturn(10)
+ when(mockShuffleDep.partitioner).thenReturn(new HashPartitioner(1))
+ when(mockShuffleDep.rdd).thenReturn(rdd)
+
+ // Prepare a MergeStatus on hostA that merges 5 out of 6 map outputs (maps 0-4)
+ val bitmap80 = new RoaringBitmap()
+ bitmap80.add(0)
+ bitmap80.add(1)
+ bitmap80.add(2)
+ bitmap80.add(3)
+ bitmap80.add(4)
+ tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+ bitmap80, 11))
+
+ val preferredLocs1 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+ assert(preferredLocs1.nonEmpty)
+ assert(preferredLocs1.length === 1)
+ assert(preferredLocs1.head === "hostA")
+
+ tracker.unregisterMergeResult(10, 0, BlockManagerId("a", "hostA", 1000))
+ // Prepare another MergeStatus on hostA that merges only 1 out of 6 map outputs (map 0)
+ val bitmap20 = new RoaringBitmap()
+ bitmap20.add(0)
+ tracker.registerMergeResult(10, 0, MergeStatus(BlockManagerId("a", "hostA", 1000),
+ bitmap20, 2))
+
+ val preferredLocs2 = tracker.getPreferredLocationsForShuffle(mockShuffleDep, 0)
+ assert(preferredLocs2.nonEmpty)
+ assert(preferredLocs2.length === 2)
+ assert(preferredLocs2 === Seq("hostA", "hostB"))
+
+ tracker.stop()
+ rpcEnv.shutdown()
+ }
+
test("SPARK-34939: remote fetch using broadcast if broadcasted value is destroyed") {
val newConf = new SparkConf
newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
@@ -346,7 +561,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
rpcEnv.stop(masterTracker.trackerEndpoint)
rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
- masterTracker.registerShuffle(20, 100)
+ masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
(0 until 100).foreach { i =>
masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
@@ -368,9 +583,85 @@ class MapOutputTrackerSuite extends SparkFunSuite {
shuffleStatus.cachedSerializedBroadcast.destroy(true)
}
val err = intercept[SparkException] {
- MapOutputTracker.deserializeMapStatuses(fetchedBytes, conf)
+ MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes, conf)
+ }
+ assert(err.getMessage.contains("Unable to deserialize broadcasted output statuses"))
+ }
+ }
+
+ test("SPARK-32921: test new protocol changes fetching both Map and Merge status in single RPC") {
+ // Tiny RPC frame plus a 10 KiB broadcast threshold so the serialized statuses are expected
+ // to go through the broadcast path (which is why a real SparkContext is needed below).
+ val newConf = new SparkConf
+ newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
+ newConf.set(RPC_ASK_TIMEOUT, "1") // Fail fast
+ newConf.set(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST, 10240L) // 10 KiB << 1MiB framesize
+ newConf.set(PUSH_BASED_SHUFFLE_ENABLED, true)
+ newConf.set(IS_TESTING, true)
+
+ // needs TorrentBroadcast so need a SparkContext
+ withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
+ val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val rpcEnv = sc.env.rpcEnv
+ // Replace the context's tracker endpoint with one built from newConf.
+ val masterEndpoint = new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, newConf)
+ rpcEnv.stop(masterTracker.trackerEndpoint)
+ rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, masterEndpoint)
+ val bitmap1 = new RoaringBitmap()
+ bitmap1.add(1)
+
+ // 100 map outputs for shuffle 20, plus one merge result for reduce partition 0.
+ masterTracker.registerShuffle(20, 100, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
+ (0 until 100).foreach { i =>
+ masterTracker.registerMapOutput(20, i, new CompressedMapStatus(
+ BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0), 5))
}
- assert(err.getMessage.contains("Unable to deserialize broadcasted map statuses"))
+ masterTracker.registerMergeResult(20, 0, MergeStatus(BlockManagerId("999", "mps", 1000),
+ bitmap1, 1000L))
+
+ val mapWorkerRpcEnv = createRpcEnv("spark-worker", "localhost", 0, new SecurityManager(conf))
+ val mapWorkerTracker = new MapOutputTrackerWorker(conf)
+ // The worker RpcEnv is created by this test (not by the SparkContext), so it must be shut
+ // down here; the finally block releases it even when an assertion below fails.
+ try {
+ mapWorkerTracker.trackerEndpoint =
+ mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+ // One GetMapAndMergeResultStatuses request returns both payloads:
+ // (serialized map statuses, serialized merge statuses).
+ val fetchedBytes = mapWorkerTracker.trackerEndpoint
+ .askSync[(Array[Byte], Array[Byte])](GetMapAndMergeResultStatuses(20))
+ assert(masterTracker.getNumAvailableMergeResults(20) == 1)
+ assert(masterTracker.getNumAvailableOutputs(20) == 100)
+
+ val mapOutput =
+ MapOutputTracker.deserializeOutputStatuses[MapStatus](fetchedBytes._1, newConf)
+ val mergeOutput =
+ MapOutputTracker.deserializeOutputStatuses[MergeStatus](fetchedBytes._2, newConf)
+ assert(mapOutput.length == 100)
+ assert(mergeOutput.length == 1)
+ } finally {
+ mapWorkerTracker.stop()
+ masterTracker.stop()
+ mapWorkerRpcEnv.shutdown()
+ }
}
}
+
+ test("SPARK-32921: unregister merge result if it is present and contains the map Id") {
+ val rpcEnv = createRpcEnv("test")
+ val tracker = newTrackerMaster()
+ tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+ new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+ val shuffleId = 10
+ tracker.registerShuffle(shuffleId, 4, 2)
+ assert(tracker.containsShuffle(shuffleId))
+
+ // Builds the bitmap of map indices that a merged partition was assembled from.
+ def mergedFrom(mapIndices: Int*): RoaringBitmap = {
+ val merged = new RoaringBitmap()
+ mapIndices.foreach { mapIndex: Int => merged.add(mapIndex) }
+ merged
+ }
+ val locA = BlockManagerId("a", "hostA", 1000)
+ val locB = BlockManagerId("b", "hostB", 1000)
+
+ // Partition 0 is merged on hostA from maps {0, 1}; partition 1 on hostB from maps {5, 6}.
+ tracker.registerMergeResult(shuffleId, 0, MergeStatus(locA, mergedFrom(0, 1), 1000L))
+ tracker.registerMergeResult(shuffleId, 1, MergeStatus(locB, mergedFrom(5, 6), 1000L))
+ assert(tracker.getNumAvailableMergeResults(shuffleId) == 2)
+
+ // Map 0 contributed to partition 0's merge result, so that result is dropped.
+ tracker.unregisterMergeResult(shuffleId, 0, locA, Option(0))
+ assert(tracker.getNumAvailableMergeResults(shuffleId) == 1)
+ // Map 1 did not contribute to partition 1's merge result, so nothing is dropped.
+ tracker.unregisterMergeResult(shuffleId, 1, locB, Option(1))
+ assert(tracker.getNumAvailableMergeResults(shuffleId) == 1)
+ // Map 5 did contribute to partition 1's merge result, so it is dropped as well.
+ tracker.unregisterMergeResult(shuffleId, 1, locB, Option(5))
+ assert(tracker.getNumAvailableMergeResults(shuffleId) == 0)
+
+ tracker.stop()
+ rpcEnv.shutdown()
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
index e433f42..d808823 100644
--- a/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
+++ b/core/src/test/scala/org/apache/spark/MapStatusesSerDeserBenchmark.scala
@@ -19,7 +19,7 @@ package org.apache.spark
import org.apache.spark.benchmark.Benchmark
import org.apache.spark.benchmark.BenchmarkBase
-import org.apache.spark.scheduler.CompressedMapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, MergeStatus}
import org.apache.spark.storage.BlockManagerId
/**
@@ -50,7 +50,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
val shuffleId = 10
- tracker.registerShuffle(shuffleId, numMaps)
+ tracker.registerShuffle(shuffleId, numMaps, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
val r = new scala.util.Random(912)
(0 until numMaps).foreach { i =>
tracker.registerMapOutput(shuffleId, i,
@@ -66,7 +66,7 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
var serializedMapStatusSizes = 0
var serializedBroadcastSizes = 0
- val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeMapStatuses(
+ val (serializedMapStatus, serializedBroadcast) = MapOutputTracker.serializeOutputStatuses(
shuffleStatus.mapStatuses, tracker.broadcastManager, tracker.isLocal, minBroadcastSize,
sc.getConf)
serializedMapStatusSizes = serializedMapStatus.length
@@ -75,12 +75,12 @@ object MapStatusesSerDeserBenchmark extends BenchmarkBase {
}
benchmark.addCase("Serialization") { _ =>
- MapOutputTracker.serializeMapStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
+ MapOutputTracker.serializeOutputStatuses(shuffleStatus.mapStatuses, tracker.broadcastManager,
tracker.isLocal, minBroadcastSize, sc.getConf)
}
benchmark.addCase("Deserialization") { _ =>
- val result = MapOutputTracker.deserializeMapStatuses(serializedMapStatus, sc.getConf)
+ val result = MapOutputTracker.deserializeOutputStatuses(serializedMapStatus, sc.getConf)
assert(result.length == numMaps)
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 56684d9..126faec 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark
import java.io.File
import java.util.{Locale, Properties}
-import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService}
+import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService }
import scala.collection.JavaConverters._
@@ -33,7 +33,7 @@ import org.apache.spark.internal.config
import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MapStatus, MergeStatus, MyRDD, SparkListener, SparkListenerTaskEnd}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId, ShuffleIndexBlockId}
@@ -367,7 +367,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
val shuffleMapRdd = new MyRDD(sc, 1, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
val shuffleHandle = manager.registerShuffle(0, shuffleDep)
- mapTrackerMaster.registerShuffle(0, 1)
+ mapTrackerMaster.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
// first attempt -- its successful
val context1 =
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 055ee0d..707e168 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -53,7 +53,7 @@ import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, Transpo
import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExecutorDiskUtils, ExternalBlockStoreClient}
import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}
-import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, SparkListenerBlockUpdated}
+import org.apache.spark.scheduler.{LiveListenerBus, MapStatus, MergeStatus, SparkListenerBlockUpdated}
import org.apache.spark.scheduler.cluster.{CoarseGrainedClusterMessages, CoarseGrainedSchedulerBackend}
import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite}
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager}
@@ -1956,7 +1956,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
Files.write(bm1.diskBlockManager.getFile(shuffleIndex).toPath(), shuffleIndexBlockContent)
Files.write(bm2.diskBlockManager.getFile(shuffleIndex2).toPath(), shuffleIndexBlockContent)
- mapOutputTracker.registerShuffle(0, 1)
+ mapOutputTracker.registerShuffle(0, 1, MergeStatus.SHUFFLE_PUSH_DUMMY_NUM_REDUCES)
val decomManager = new BlockManagerDecommissioner(conf, bm1)
try {
mapOutputTracker.registerMapOutput(0, 0, MapStatus(bm1.blockManagerId, Array(blockSize), 0))
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org