You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2020/03/25 05:53:16 UTC

[spark] branch branch-3.0 updated: [SPARK-31207][CORE] Ensure the total number of blocks to fetch equals to the sum of local/hostLocal/remote blocks

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 5da878b  [SPARK-31207][CORE] Ensure the total number of blocks to fetch equals to the sum of local/hostLocal/remote blocks
5da878b is described below

commit 5da878b913606ed6dddee9b1b6c6c5770a8a5e45
Author: Xingbo Jiang <xi...@databricks.com>
AuthorDate: Wed Mar 25 13:19:43 2020 +0800

    [SPARK-31207][CORE] Ensure the total number of blocks to fetch equals to the sum of local/hostLocal/remote blocks
    
    ### What changes were proposed in this pull request?
    
    Assert that the number of blocks to fetch equals the sum of the number of local blocks, the number of hostLocal blocks, and the number of remote blocks in ShuffleBlockFetcherIterator. Also refactor the code a bit to make it easier to follow.
    
    ### Why are the changes needed?
    
    When the numbers don't match, it means something has gone wrong, so we should fail fast.
    
    ### Does this PR introduce any user-facing change?
    
    No. This is basically code refactoring.
    
    ### How was this patch tested?
    
    Tested with existing test suites.
    
    Closes #27972 from jiangxb1987/BlockFetcher.
    
    Authored-by: Xingbo Jiang <xi...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit a03fbfbdd5758f6f2798510783c830c432ba0367)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../storage/ShuffleBlockFetcherIterator.scala      | 84 +++++++++++++---------
 1 file changed, 49 insertions(+), 35 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 2a0447d..f1a7d88 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -91,8 +91,7 @@ final class ShuffleBlockFetcherIterator(
   private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)
 
   /**
-   * Total number of blocks to fetch. This should be equal to the total number of blocks
-   * in [[blocksByAddress]] because we already filter out zero-sized blocks in [[blocksByAddress]].
+   * Total number of blocks to fetch.
    */
   private[this] var numBlocksToFetch = 0
 
@@ -290,7 +289,6 @@ final class ShuffleBlockFetcherIterator(
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
-    var numRemoteBlocks = 0
 
     val hostLocalDirReadingEnabled =
       blockManager.hostLocalDirManager != null && blockManager.hostLocalDirManager.isDefined
@@ -299,25 +297,29 @@ final class ShuffleBlockFetcherIterator(
       if (address.executorId == blockManager.blockManagerId.executorId) {
         checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
-          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)).to[ArrayBuffer])
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)))
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (hostLocalDirReadingEnabled && address.host == blockManager.blockManagerId.host) {
         checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
-          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)).to[ArrayBuffer])
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)))
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
         hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
-        numRemoteBlocks += blockInfos.size
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
+    val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
     val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
+    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
+      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
+        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks}.")
     logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
       s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
       s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
@@ -325,6 +327,40 @@ final class ShuffleBlockFetcherIterator(
     collectedRemoteRequests
   }
 
+  private def createFetchRequest(
+      blocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      curRequestSize: Long): FetchRequest = {
+    logDebug(s"Creating fetch request of $curRequestSize at $address "
+      + s"with ${blocks.size} blocks")
+    FetchRequest(address, blocks)
+  }
+
+  private def createFetchRequests(
+      curBlocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      curRequestSize: Long,
+      isLast: Boolean,
+      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
+    var retBlocks = Seq.empty[FetchBlockInfo]
+    if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
+      collectedRemoteRequests += createFetchRequest(mergedBlocks, address, curRequestSize)
+    } else {
+      mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
+        if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
+          collectedRemoteRequests += createFetchRequest(blocks, address, curRequestSize)
+        } else {
+          // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
+          // to `curBlocks`.
+          retBlocks = blocks
+          numBlocksToFetch -= blocks.size
+        }
+      }
+    }
+    retBlocks
+  }
+
   private def collectFetchRequests(
       address: BlockManagerId,
       blockInfos: Seq[(BlockId, Long, Int)],
@@ -333,32 +369,6 @@ final class ShuffleBlockFetcherIterator(
     var curRequestSize = 0L
     var curBlocks = new ArrayBuffer[FetchBlockInfo]
 
-    def createFetchRequest(blocks: Seq[FetchBlockInfo]): Unit = {
-      collectedRemoteRequests += FetchRequest(address, blocks)
-      logDebug(s"Creating fetch request of $curRequestSize at $address "
-        + s"with ${blocks.size} blocks")
-    }
-
-    def createFetchRequests(isLast: Boolean): Unit = {
-      val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
-      curBlocks = new ArrayBuffer[FetchBlockInfo]
-      if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
-        createFetchRequest(mergedBlocks)
-      } else {
-        mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
-          if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
-            createFetchRequest(blocks)
-          } else {
-            // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
-            // to `curBlocks`.
-            curBlocks = blocks
-            numBlocksToFetch -= blocks.size
-          }
-        }
-      }
-      curRequestSize = curBlocks.map(_.size).sum
-    }
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
       assertPositiveBlockSize(blockId, size)
@@ -367,12 +377,16 @@ final class ShuffleBlockFetcherIterator(
       // For batch fetch, the actual block in flight should count for merged block.
       val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
       if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        createFetchRequests(isLast = false)
+        curBlocks = createFetchRequests(curBlocks, address, curRequestSize, isLast = false,
+          collectedRemoteRequests).to[ArrayBuffer]
+        curRequestSize = curBlocks.map(_.size).sum
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
-      createFetchRequests(isLast = true)
+      curBlocks = createFetchRequests(curBlocks, address, curRequestSize, isLast = true,
+        collectedRemoteRequests).to[ArrayBuffer]
+      curRequestSize = curBlocks.map(_.size).sum
     }
   }
 
@@ -389,7 +403,7 @@ final class ShuffleBlockFetcherIterator(
   }
 
   private[this] def mergeContinuousShuffleBlockIdsIfNeeded(
-      blocks: ArrayBuffer[FetchBlockInfo]): ArrayBuffer[FetchBlockInfo] = {
+      blocks: Seq[FetchBlockInfo]): Seq[FetchBlockInfo] = {
     val result = if (doBatchFetch) {
       var curBlocks = new ArrayBuffer[FetchBlockInfo]
       val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org