Posted to reviews@spark.apache.org by GitBox <gi...@apache.org> on 2021/04/13 03:46:37 UTC

[GitHub] [spark] otterc opened a new pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

otterc opened a new pull request #32140:
URL: https://github.com/apache/spark/pull/32140


   This is a work in progress, as it depends on these in-progress JIRAs:
   - SPARK-32921
   - SPARK-33350
   
   ### What changes were proposed in this pull request?
   This is the shuffle fetch side change where executors can fetch local/remote merged shuffle data from shuffle services. This is needed for push-based shuffle - SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
   
   This change introduces new messages between clients and the external shuffle service:
   
   1. `MergedBlockMetaRequest`: The client sends this to the external shuffle service to get the meta information for a merged block. The response is one of the following:
     - `MergedBlockMetaSuccess`: contains the request id, the number of chunks, and a `ManagedBuffer`, which is a `FileSegmentManagedBuffer` backed by the merged block meta file.
     - `RpcFailure`: sent back to the client in case of failure. This is an existing message.
   
   2. `FetchShuffleBlockChunks`: This is similar to the `FetchShuffleBlocks` message, but it fetches merged shuffle chunks instead of blocks.
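
The metadata exchange described in this list can be modeled roughly as follows. This is only a sketch: the case classes are hypothetical stand-ins written for illustration, not Spark's actual message classes, and the in-memory `store` map and the `appId` field are assumptions. The real success response carries a `ManagedBuffer`, which is replaced here by a plain byte array.

```scala
object MergedMetaProtocolSketch {
  // Hypothetical stand-ins for the wire messages; NOT Spark's actual classes.
  sealed trait Response
  final case class MergedBlockMetaRequest(
      requestId: Long, appId: String, shuffleId: Int, reduceId: Int)
  final case class MergedBlockMetaSuccess(
      requestId: Long, numChunks: Int, meta: Array[Byte]) extends Response
  final case class RpcFailure(requestId: Long, error: String) extends Response

  // Toy "shuffle service": looks up the chunk count for (shuffleId, reduceId)
  // in an in-memory map (an assumption made purely for this illustration).
  def handle(req: MergedBlockMetaRequest, store: Map[(Int, Int), Int]): Response =
    store.get((req.shuffleId, req.reduceId)) match {
      case Some(numChunks) =>
        MergedBlockMetaSuccess(req.requestId, numChunks, Array.emptyByteArray)
      case None =>
        RpcFailure(req.requestId, s"no merged meta for (${req.shuffleId}, ${req.reduceId})")
    }
}
```

A client would match on the response type: on success it learns how many chunks to request next, and on failure it can fall back to the original unmerged blocks.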
   
   ### Why are the changes needed?
   These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
   
   ### Does this PR introduce _any_ user-facing change?
   When push-based shuffle is turned on, executors will fetch merged shuffle block chunks from the remote shuffle service. The client logs will indicate this.
   
   ### How was this patch tested?
   Added unit tests.
   The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
   We have already verified the functionality and the improved performance as documented in the SPIP doc.
   
   Lead-authored-by: Chandni Singh chsingh@linkedin.com
   Co-authored-by: Ye Zhou yezhou@linkedin.com
   Co-authored-by: Min Shen mshen@linkedin.com
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870101000








[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-868939032


   The GitHub Actions test failure looks unrelated; let me try Jenkins anyway.




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655034874



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it
+   * can be modified by both the task thread and the netty thread.
+   */
+  private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote merged block.
+   */
+  def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = {
+    assert(isMergedShuffleBlockAddress(address))
+    address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap.get(blockId).getCardinality
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps   per chunk bitmap, where each bitmap contains all the mapIds that are merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding merged local blocks.
+   * @param mergedLocalBlocks set of identified merged local blocks.
+   */
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult(
+              blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated. This can also be executed by the task thread as
+   * well as the netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.incrementNumBlocksToFetch(bufs.size - 1)

Review comment:
       I have made this change, but it now happens when the iterator processes `PushMergedLocalMetaFetchResult`.
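
The arithmetic behind the `incrementNumBlocksToFetch(bufs.size - 1)` line in the diff above can be sketched as follows. This is an illustrative helper only; `adjustedNumBlocksToFetch` is a hypothetical name, not a method of the iterator. The assumption is that a merged block has already been counted once, so expanding it into its chunks adds one fewer than the number of chunks.

```scala
object BlockCountSketch {
  // Hypothetical helper: a merged block was already counted as 1 block to fetch,
  // so replacing it with `numChunks` chunk fetches adds (numChunks - 1) to the total.
  def adjustedNumBlocksToFetch(current: Int, numChunks: Int): Int = {
    require(numChunks >= 1, "a merged block is assumed to have at least one chunk")
    current + (numChunks - 1)
  }
}
```

For example, with 10 blocks pending and a merged block that turns out to hold 4 chunks, the total becomes 13.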






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648496347



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1074,8 +1337,13 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
    * @param blocks Sequence of the information for blocks to fetch from the same address.
+   * @param hasMergedBlocks true if this request contains merged blocks; false if it contains

Review comment:
       All blocks are merged blocks. Will change the comment.






[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648819624



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)

Review comment:
       That's actually my concern. I noticed the check inside `collectFetchRequests` too. Couldn't we check just once, before the if condition, and drop the check inside `collectFetchRequests`?
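
A minimal sketch of the suggestion, assuming a simplified model: `BlockInfo`, `partition`, and the validation rule inside `checkBlockSizes` are all hypothetical placeholders and do not reproduce the iterator's real types or checks. The point is only that validation runs once per address, before the code branches into local and remote handling.

```scala
object CheckOnceSketch {
  // Hypothetical, simplified block descriptor.
  final case class BlockInfo(id: String, size: Long)

  // Placeholder validation rule (an assumption): sizes must not be negative.
  def checkBlockSizes(blocks: Seq[BlockInfo]): Unit =
    blocks.foreach(b => require(b.size >= 0, s"negative size for ${b.id}"))

  // Splits blocks into (local, remote), validating each group exactly once.
  def partition(
      blocksByAddress: Seq[(String, Seq[BlockInfo])],
      localHost: String): (Seq[BlockInfo], Seq[BlockInfo]) = {
    val local = Seq.newBuilder[BlockInfo]
    val remote = Seq.newBuilder[BlockInfo]
    for ((host, blocks) <- blocksByAddress) {
      checkBlockSizes(blocks) // single check, hoisted above the branch
      if (host == localHost) local ++= blocks else remote ++= blocks
    }
    (local.result(), remote.result())
  }
}
```

With the check hoisted, neither branch (nor any helper the remote branch calls) needs to repeat it.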






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656715909



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)

Review comment:
       Merged chunks are going to be 2 MB; only the last chunk in a merged file can be smaller than 2 MB. So this can still throw an exception, but I am handling the exception in the following catch block to initiate a fallback. Do you see any issues with this?
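
The handling described in this comment can be sketched roughly as below. This is a simplified illustration under assumptions: `Outcome`, `open`, and the `wrap` parameter are hypothetical names, with `wrap` standing in for the stream wrapper that may throw an `IOException` on corrupt data. For a merged chunk the exception starts a fallback instead of failing the task; for any other block the exception simply propagates here, whereas the real iterator retries or reports a fetch failure.

```scala
import java.io.{IOException, InputStream}

object ChunkFallbackSketch {
  sealed trait Outcome
  case object Wrapped extends Outcome
  case object FallbackInitiated extends Outcome

  // `wrap` is a hypothetical stand-in for decompression/decryption of the stream.
  def open(isShuffleChunk: Boolean, in: InputStream)(
      wrap: InputStream => InputStream): Outcome =
    try {
      wrap(in)
      Wrapped
    } catch {
      // Only merged chunks take the fallback path; other IOExceptions propagate.
      case _: IOException if isShuffleChunk => FallbackInitiated
    }
}
```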






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649435207



##########
File path: core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -22,31 +22,40 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
 import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 

Review comment:
       > a) deserialization failure results in initiating fallback.
   
   Yes, the test `fallback to original shuffle block when a merged block chunk is corrupt` does this. It tests the fallback when the shuffle merged chunk is deserialized during processing of `SuccessFetchResult`.
   
   > b) fetch failure of both merged block and fallback block should get reported to driver as fetch failure.
   
   When there is a fetch failure of a merged block, the iterator falls back to fetching the original blocks. So we don't report that to the driver, because the task didn't fail because of it. It tries to fetch the original blocks that make up that merged block.
   
   I have added tests for the various conditions that trigger fallback, but these simulate successful fetches of all original blocks. I haven't added a test in which fallback is triggered but the iterator then fails to fetch an original block and throws `FetchFailedException`. That path follows the existing code, but I will still add a test for it.
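
The two-level behavior described in this reply can be illustrated with a small model. Everything here is hypothetical: `fetch`, `fetchMerged`, `fetchOriginal`, and the local `FetchFailed` exception are stand-ins for illustration, not Spark APIs. A failed merged fetch is swallowed and triggers fetches of the original blocks; only a failure of an original block surfaces as an error.

```scala
object FallbackFetchSketch {
  // Hypothetical stand-in for a fetch failure that would be reported to the driver.
  final class FetchFailed(msg: String) extends RuntimeException(msg)

  def fetch(
      fetchMerged: () => Either[String, Seq[Byte]],
      originals: Seq[String],
      fetchOriginal: String => Either[String, Seq[Byte]]): Seq[Byte] =
    fetchMerged() match {
      case Right(data) => data
      case Left(_) =>
        // Merged fetch failed: not reported; fall back to the original blocks.
        originals.flatMap { id =>
          fetchOriginal(id) match {
            case Right(data) => data
            case Left(err) => throw new FetchFailed(s"$id: $err")
          }
        }
    }
}
```

A test for the missing case mentioned above would correspond to the last branch: the merged fetch fails, the fallback starts, and one original block then fails as well.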
   






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656715909



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)

Review comment:
       Merged chunks are going to be 2 MB. Only the last chunk in a merged file can be smaller than 2 MB, so this can still throw an exception. However, I am handling the exception in the following catch block to initiate a fallback. Do you see any issues with this?
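
For illustration only, the error handling described here can be sketched as a small decision function in plain Java. The removed lines in the diff fail the fetch when the buffer is a `FileSegmentManagedBuffer` or the block was already retried, and otherwise fetch again; the new branch falls back for shuffle chunks. The class name, enum, and parameter names below are hypothetical, and the ordering of the checks is an assumption rather than the code in this PR.

```java
/** Hypothetical decision helper; names and check ordering are assumptions, not Spark code. */
final class CorruptionPolicySketch {
  enum Action { FALLBACK_TO_ORIGINAL_BLOCKS, RETRY_FETCH, FAIL_FETCH }

  /** Picks what to do when wrapping or pre-reading the fetched stream throws an IOException. */
  static Action onStreamError(
      boolean isShuffleChunk, boolean isFileBacked, boolean alreadyRetried) {
    if (isShuffleChunk) {
      // A failing merged chunk is abandoned in favor of the original unmerged blocks.
      return Action.FALLBACK_TO_ORIGINAL_BLOCKS;
    }
    if (isFileBacked || alreadyRetried) {
      // Mirrors the pre-existing rule: disk-backed buffers and repeat failures are fatal.
      return Action.FAIL_FETCH;
    }
    // First failure of an in-memory regular block: fetch it once more.
    return Action.RETRY_FETCH;
  }

  public static void main(String[] args) {
    System.out.println(onStreamError(true, false, false));
    System.out.println(onStreamError(false, false, false));
    System.out.println(onStreamError(false, false, true));
  }
}
```

The point of the sketch is only that the chunk case never reaches the fetch-failure path, so a corrupt chunk costs a fallback rather than a stage retry.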




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r657129548



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)

Review comment:
       > What I am trying to understand is, we will be initiating a fallback and discarding merge even though there is nothing really wrong here - other than the fact that chunk was too small to decompress - right ? (in case chunk was split at a boundary which causes decompression to fail).
   
   A merged shuffle chunk contains shuffle blocks in their entirety; it never contains a partial shuffle block. This is documented for the configuration `minChunkSizeInMergedShuffleFile`:
   ```
     /**
      * The minimum size of a chunk when dividing a merged shuffle file into multiple chunks during
      * push-based shuffle.
      * A merged shuffle file consists of multiple small shuffle blocks. Fetching the
      * complete merged shuffle file in a single response increases the memory requirements for the
      * clients. Instead of serving the entire merged file, the shuffle service serves the
      * merged file in `chunks`. A `chunk` constitutes few shuffle blocks in entirety and this
      * configuration controls how big a chunk can get. A corresponding index file for each merged
      * shuffle file will be generated indicating chunk boundaries.
      */
     public int minChunkSizeInMergedShuffleFile() {
       return Ints.checkedCast(JavaUtils.byteStringAsBytes(
           conf.get("spark.shuffle.server.minChunkSizeInMergedShuffleFile", "2m")));
     }
     ```
   So if this fails for a shuffle chunk, it would be because the shuffle chunk was corrupt.
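
To make the boundary rule concrete, here is a minimal sketch, assuming a chunk is closed only after a whole block has been appended and the chunk has reached the minimum size. `ChunkBoundarySketch` and `chunkEndOffsets` are hypothetical names for illustration; this is not the shuffle service's implementation.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration of block-aligned chunk boundaries; not the shuffle service code. */
final class ChunkBoundarySketch {
  /**
   * Returns the end offset of every chunk, given the sizes of the blocks appended to a
   * merged file in order. A chunk is closed only at a block boundary, once it holds at
   * least minChunkSize bytes, so no block is ever split across two chunks.
   */
  static List<Long> chunkEndOffsets(long[] blockSizes, long minChunkSize) {
    List<Long> ends = new ArrayList<>();
    long offset = 0L;
    long chunkStart = 0L;
    for (long size : blockSizes) {
      offset += size; // the whole block is appended before any boundary check
      if (offset - chunkStart >= minChunkSize) {
        ends.add(offset);
        chunkStart = offset;
      }
    }
    if (offset > chunkStart) {
      ends.add(offset); // only the trailing chunk may be smaller than minChunkSize
    }
    return ends;
  }

  public static void main(String[] args) {
    long mb = 1024L * 1024L;
    System.out.println(chunkEndOffsets(new long[] {mb, mb, 3 * mb, mb / 2}, 2 * mb));
  }
}
```

Under this assumption a 3 MB block yields a 3 MB chunk rather than being cut at 2 MB, which is why a decompression failure on a chunk points to corruption and not to a split block.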






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645727804



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       The `Longs.toArray` usage for regular blocks here is existing code. In the existing code this was done at line 141:
   ```
   long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
   ```
   I just had to move it around. Should I still change this?
   
   I added `Ints.toArray` for shuffle chunks. The way I can make it clean and concise is by adding separate methods for converting a `Set<Number>` to `long[]` and to `int[]`. I can add them to a utility class if one exists in this module, or create a new one. Let me know what you think.
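
As a rough sketch of what such helpers could look like, written in plain Java without Guava: `NumberSetSketch`, `toLongArray`, and `toIntArray` are placeholder names, and where they would live in the module is left open.

```java
import java.util.LinkedHashSet;
import java.util.Set;

/** Hypothetical utility; class and method names are placeholders for illustration. */
final class NumberSetSketch {
  /** Copies the IDs into a long[] in the set's iteration order. */
  static long[] toLongArray(Set<? extends Number> ids) {
    long[] out = new long[ids.size()];
    int i = 0;
    for (Number id : ids) {
      out[i++] = id.longValue();
    }
    return out;
  }

  /** Copies the IDs into an int[]; throws ArithmeticException instead of silently truncating. */
  static int[] toIntArray(Set<? extends Number> ids) {
    int[] out = new int[ids.size()];
    int i = 0;
    for (Number id : ids) {
      out[i++] = Math.toIntExact(id.longValue());
    }
    return out;
  }

  public static void main(String[] args) {
    Set<Number> ids = new LinkedHashSet<>();
    ids.add(3L);
    ids.add(1L);
    ids.add(2L);
    System.out.println(java.util.Arrays.toString(toLongArray(ids)));
    System.out.println(java.util.Arrays.toString(toIntArray(ids)));
  }
}
```

Iteration order matters here because the key set comes from a `LinkedHashMap`, and the resulting array has to line up with the per-ID arrays built from the same map.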






[GitHub] [spark] AmplabJenkins commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
AmplabJenkins commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870985435


   
   Refer to this link for build results (access rights to CI server needed): 
   https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder-K8s/44912/
   




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640210756



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +824,66 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                numBlocksProcessed += pushBasedFetchHelper
+                  .initiateFallbackBlockFetchForMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>

Review comment:
       Note to self: most of this is as before. I have added conditions for shuffle chunks.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648493846



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)

Review comment:
       There is actually no reason. The method just returns the original blocks when `doBatchFetch = false`, so I was calling it anyway. I will change the code to not call it.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645664320



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
+        // For regular blocks this is mapId. For chunks this is reduceId.
         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks this is reduceId. For chunks this is chunkId.
         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
       I couldn't think of a good generic name for it. Would `primaryIdAndSecondaryIds` work? I used `primaryId` and `secondaryId` in `OneForOneBlockFetcher`.
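
For reference, the parsing in this hunk with the generic naming applied could look like the sketch below. The class and method names are hypothetical, `"shuffle"` comes from the original check in the diff, and `"shuffleChunk"` is an assumed value for the chunk prefix.

```java
/** Hypothetical standalone version of the parsing loop, using the proposed generic names. */
final class BlockIdPairsSketch {
  static final String BLOCK_PREFIX = "shuffle";      // value used by the original check
  static final String CHUNK_PREFIX = "shuffleChunk"; // assumed value, not confirmed here

  /**
   * Flattens block IDs into [primaryId0, secondaryId0, primaryId1, secondaryId1, ...].
   * For regular blocks the pair is (mapId, reduceId); for chunks it is (reduceId, chunkId).
   */
  static int[] primaryIdAndSecondaryIds(String[] blockIds, int shuffleId, boolean forChunks) {
    String expectedPrefix = forChunks ? CHUNK_PREFIX : BLOCK_PREFIX;
    int[] ids = new int[2 * blockIds.length];
    for (int i = 0; i < blockIds.length; i++) {
      String[] parts = blockIds[i].split("_");
      if (parts.length != 4 || !parts[0].equals(expectedPrefix)) {
        throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
      }
      if (Integer.parseInt(parts[1]) != shuffleId) {
        throw new IllegalArgumentException(
            "Expected shuffleId=" + shuffleId + ", got:" + blockIds[i]);
      }
      ids[2 * i] = Integer.parseInt(parts[2]);
      ids[2 * i + 1] = Integer.parseInt(parts[3]);
    }
    return ids;
  }

  public static void main(String[] args) {
    int[] ids = primaryIdAndSecondaryIds(new String[] {"shuffle_0_1_2", "shuffle_0_3_4"}, 0, false);
    System.out.println(java.util.Arrays.toString(ids));
  }
}
```

Selecting the expected prefix once up front keeps the format check to a single comparison per block ID.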






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646807905



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       This is invoked for each fetch request from the client. A fetch request is for multiple blocks.
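
As a simplified sketch of the per-request work, the grouping of one request's block IDs by primary ID can be written as below. `FetchGroupingSketch` and `groupByPrimaryId` are hypothetical names, the batch-fetch case with five ID parts is left out, and the primary ID is parsed as a `long` for both blocks and chunks to keep the example short.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical, simplified grouping of the block IDs carried by one fetch request. */
final class FetchGroupingSketch {
  /**
   * Groups secondary IDs (reduceIds or chunkIds) under their primary ID (mapId or reduceId).
   * A LinkedHashMap keeps first-seen order, so the request order stays reproducible.
   */
  static Map<Long, List<Integer>> groupByPrimaryId(String[] blockIds) {
    Map<Long, List<Integer>> grouped = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      long primaryId = Long.parseLong(parts[2]);
      grouped.computeIfAbsent(primaryId, k -> new ArrayList<>())
          .add(Integer.parseInt(parts[3]));
    }
    return grouped;
  }

  public static void main(String[] args) {
    String[] request = {"shuffle_0_7_0", "shuffle_0_3_1", "shuffle_0_7_2"};
    System.out.println(groupByPrimaryId(request));
  }
}
```

The cost is linear in the number of blocks in the request, and it is paid once per request rather than once per block.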






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656714752



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -661,18 +745,21 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
-              shuffleMetrics.incLocalBlocksFetched(1)
-              shuffleMetrics.incLocalBytesRead(buf.size)
-            } else {
-              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
-              shuffleMetrics.incRemoteBytesRead(buf.size)
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
-                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
-              }
-              shuffleMetrics.incRemoteBlocksFetched(1)
-              bytesInFlight -= size
-            }
+           if (hostLocalBlocks.contains(blockId -> mapIndex) ||
+             pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
+             // It is a host local block or a local shuffle chunk
+             shuffleMetrics.incLocalBlocksFetched(1)
+             shuffleMetrics.incLocalBytesRead(buf.size)
+           } else {
+             // Could be a remote shuffle chunk or remote block
+             numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+             shuffleMetrics.incRemoteBytesRead(buf.size)
+             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+               shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+             }
+             shuffleMetrics.incRemoteBlocksFetched(1)
+             bytesInFlight -= size
+           }

Review comment:
       Actually, @mridulm, I didn't even change the indentation here; it is the same as before. I only added lines 749, 750, and 754. I will remove the comment at line 754 and see whether this interface stops showing these lines as changed.






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660751432



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"+ the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration

Review comment:
       I have made this change, and I also added a var that counts the number of hostLocalBlocks, which is needed for the assertions. PTAL.
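       As a rough illustration of the bookkeeping this assertion relies on, here is a minimal sketch. It assumes, as a later hunk's doc comment suggests, that partitioning can run more than once (initially and again on fallback), so each pass is checked against only the blocks discovered in that pass. The class and method names are hypothetical and this is not the PR's code.

```java
/** Hypothetical bookkeeping for one pass of block partitioning; not Spark's class. */
final class BlockCountSketch {
  // Running total across all passes (the initial pass plus any fallback passes).
  private int numBlocksToFetch = 0;

  /**
   * Records one pass and checks that every block discovered in this pass was
   * placed in exactly one category. Returns the number of blocks in this pass.
   */
  int checkPass(int discovered, int local, int hostLocal, int pushMergedLocal, int remote) {
    int prevNumBlocksToFetch = numBlocksToFetch;
    numBlocksToFetch += discovered;
    // Compare against this pass only, not the running total.
    int blocksThisPass = numBlocksToFetch - prevNumBlocksToFetch;
    int categorized = local + hostLocal + pushMergedLocal + remote;
    if (blocksThisPass != categorized) {
      throw new IllegalStateException(
          "The number of non-empty blocks " + blocksThisPass + " doesn't equal "
              + "local " + local + " + host-local " + hostLocal
              + " + push-merged-local " + pushMergedLocal + " + remote " + remote);
    }
    return blocksThisPass;
  }

  int total() {
    return numBlocksToFetch;
  }
}
```

       The host-local count is passed in per pass here, which mirrors the idea of keeping a separate counter for the current iteration rather than reading the size of a collection that accumulates across iterations.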






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660721217



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"+ the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration

Review comment:
       We would also need to do something like that to find the number of hostLocalBlocks for the earlier assertions.






[GitHub] [spark] otterc commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870260954


   > Sorry for the delay. I'll do a review today. BTW, are there any other necessary magnet PRs that have to be merged for the 3.2 cut/release?
   
   There are 2 pending tasks which are necessary for Magnet:
   - SPARK-35546: [#33078](https://github.com/apache/spark/pull/33078). This one is ready for review
   - SPARK-32923: [#33034](https://github.com/apache/spark/pull/33034). This is WIP




[GitHub] [spark] mridulm commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-859274853


   This is #32140 @otterc, you meant #32811 :-)
   I did not check the PR id and followed the last comment twice :-)




[GitHub] [spark] AmplabJenkins commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
AmplabJenkins commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-818410050


   Can one of the admins verify this patch?




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648685526



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fall back to fetching the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    })
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        if (isNotExecutorOrMergedLocal(address)) {

Review comment:
       Added comments. PTAL.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648595988



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1074,8 +1337,13 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
    * @param blocks Sequence of the information for blocks to fetch from the same address.
+   * @param hasMergedBlocks true if this request contains merged blocks; false if it contains

Review comment:
       I also changed the name of this param to `forMergedMetas`, which makes it clearer.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648683194



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fall back to fetching the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()

Review comment:
       done






[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648285195



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $mergedLocalBlockBytes")

Review comment:
       I guess you actually want to log `pushMergedBlockInfos.map(_.size).sum` here instead of the accumulated `mergedLocalBlockBytes`?
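
A minimal sketch of the distinction raised above, under stated assumptions: `Info` is a hypothetical stand-in for `FetchBlockInfo`, and the sizes are made up. It only illustrates that a counter incremented inside the loop is a running total, so logging it would report more than the blocks just collected if the branch ran for more than one address.

```scala
// Hypothetical stand-in for FetchBlockInfo; only the size matters here.
final case class Info(size: Long)

object MergedLocalLogSketch {
  // For each address, returns (bytes collected for this address, running total so far).
  def perAddressAndRunning(groups: Seq[Seq[Info]]): Seq[(Long, Long)] = {
    var running = 0L
    groups.map { infos =>
      val current = infos.map(_.size).sum
      running += current
      (current, running)
    }
  }

  def main(args: Array[String]): Unit = {
    val groups = Seq(Seq(Info(10L), Info(20L)), Seq(Info(5L)))
    perAddressAndRunning(groups).foreach { case (current, running) =>
      println(s"this address: $current bytes, accumulated: $running bytes")
    }
  }
}
```

The two values agree only for the first address; after that the accumulated figure no longer describes the blocks named in the log message.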

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],

Review comment:
       nit: `mutable.LinkedHashSet` ?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -376,48 +418,62 @@ final class ShuffleBlockFetcherIterator(
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"[${context.taskAttemptId()}] Getting $blocksToFetchCurrentIteration " +

Review comment:
       If you want to log the task info here, could you follow the existing task name format?
   https://github.com/apache/spark/blob/7d8181b62f17a202ba584c7bba65b61ec4724db2/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L558

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>

Review comment:
       Could we log the error here?
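
One possible shape for this, as a hedged sketch rather than the PR's code: the `logError` below is a hypothetical recorder standing in for Spark's logging, the message text is invented, and catching `NonFatal` instead of `Throwable` is a substitution made here, not something the comment asks for.

```scala
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

object LogOnFailureSketch {
  // Hypothetical logger standing in for Spark's logError; it just records messages.
  val logged = ArrayBuffer.empty[String]
  def logError(msg: String, e: Throwable): Unit = logged += s"$msg: ${e.getMessage}"

  // Evaluates `onSuccess`; if it throws a non-fatal exception, logs it and
  // evaluates `onFailure` instead, so the failure is visible rather than swallowed.
  def enqueueWithLogging[A](onSuccess: => A)(onFailure: => A): A =
    try onSuccess
    catch {
      case NonFatal(e) =>
        logError("Failed to parse the meta of merged block", e)
        onFailure
    }
}
```

In the listener, `onSuccess` would correspond to enqueueing the `MergedBlocksMetaFetchResult` and `onFailure` to enqueueing the `MergedBlocksMetaFailedFetchResult`.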

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()

Review comment:
       nit: `blocksToRequest` -> `blocksToFetch`?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)

Review comment:
       When `doBatchFetch = false`, `mergeContinuousShuffleBlockIdsIfNeeded` simply returns the block infos unchanged. Why do we need to call it?
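
A simplified model of the point above, with assumptions called out: `BlockInfoSketch` and `mergeIfNeeded` are hypothetical stand-ins, and the batching branch is a placeholder rather than the real merging logic. Only the `doBatchFetch = false` path reflects the behaviour described in the comment.

```scala
// Hypothetical, simplified stand-in for FetchBlockInfo.
final case class BlockInfoSketch(blockId: String, size: Long, mapIndex: Int)

object NoBatchFetchSketch {
  // Simplified model of the helper: with doBatchFetch = false the input is
  // returned as-is. The batching path is a placeholder that collapses all
  // entries into one and is NOT the real merging logic.
  def mergeIfNeeded(
      blocks: Seq[BlockInfoSketch],
      doBatchFetch: Boolean): Seq[BlockInfoSketch] =
    if (doBatchFetch && blocks.nonEmpty) {
      Seq(BlockInfoSketch(blocks.map(_.blockId).mkString("+"), blocks.map(_.size).sum, -1))
    } else {
      blocks
    }

  // Building the infos directly from the raw tuples, without the merge call.
  def direct(raw: Seq[(String, Long, Int)]): Seq[BlockInfoSketch] =
    raw.map { case (id, size, mapIndex) => BlockInfoSketch(id, size, mapIndex) }
}
```

Under this model, mapping the tuples directly gives the same sequence as mapping and then calling the helper with `doBatchFetch = false`, which is why the call looks redundant.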

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1074,8 +1337,13 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
    * @param blocks Sequence of the information for blocks to fetch from the same address.
+   * @param hasMergedBlocks true if this request contains merged blocks; false if it contains

Review comment:
       "contains" or all blocks must be merged blocks?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>

Review comment:
       nit: 2 indents

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -835,8 +1013,11 @@ final class ShuffleBlockFetcherIterator(
 
     def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {

Review comment:
       How about:
   
   ```scala
   logDebug("Sending request for %d blocks (%s) from %s".format(
         request.blocks.size, Utils.bytesToString(request.size), request.address.hostPort))
   if (hasMergedBlocks) {
    pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
   } else {
    sendRequest(request)
    numBlocksInFlightPerAddress(remoteAddress) = 
      numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
   }
   ```

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {

Review comment:
       nit: 
   ```scala
   req.blocks.foreach { block =>
    ...
   }
   ```

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)

Review comment:
       I'm wondering why we don't check block sizes up front, since we'd check them for every type of block anyway?
   e.g.,
   ```scala
       for ((address, blockInfos) <- blocksByAddress) {
         checkBlockSizes(blockInfos)
         ....
       }
   ```

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()

Review comment:
       ```suggestion
       val blocksToRequest = new ArrayBuffer[(BlockId, Long, Int)]()
   ```

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       How about making it a trait? And I think we can put it into a separate file since it's not small.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       Hm, is it possible that there are only `FetchRequest(hasMergedBlocks)` requests at the beginning? In that case, this seems to cause the fetching process to hang.
   
   We probably need to call `fetchUpToMaxBytes()` here if `reqsInFlight=0`.
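
   To illustrate the concern, here is a toy model rather than the actual `ShuffleBlockFetcherIterator`. `ToyFetcher`, `handleMetaResult`, and the string-based requests are hypothetical stand-ins that only mirror the field names in the diff. The sketch assumes that newly created chunk requests are merely queued, so if nothing is in flight, no further result can arrive unless the queue is drained. Whether the real loop already covers this case is not visible in the quoted hunk.

   ```scala
   import scala.collection.mutable

   // Toy model only: names are hypothetical and merely mirror the fields in the diff.
   final class ToyFetcher {
     val fetchRequests = mutable.Queue[String]() // queued, not yet sent
     var reqsInFlight = 0                        // requests currently on the wire

     // Stand-in for fetchUpToMaxBytes(): send everything that is queued.
     def fetchUpToMaxBytes(): Unit = {
       while (fetchRequests.nonEmpty) {
         fetchRequests.dequeue()
         reqsInFlight += 1
       }
     }

     // Stand-in for handling a meta result: queue the chunk requests, then
     // start sending if nothing is in flight so the caller's loop cannot stall.
     def handleMetaResult(chunkRequests: Seq[String]): Unit = {
       fetchRequests ++= chunkRequests
       if (reqsInFlight == 0) fetchUpToMaxBytes()
     }
   }
   ```

   With this guard, a run that starts with only meta requests still ends up with chunk requests in flight after the first meta result is handled.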

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    })
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        if (isNotExecutorOrMergedLocal(address)) {

Review comment:
       Could you add some comments to help to understand the logic here better?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))

Review comment:
       nit: `sizeMap(shuffleId, reduceId)` -> `sizeMap((shuffleId, reduceId))`
   
   (As far as I know, there's a Scala compiler that doesn't support the former syntax for the tuple.)
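
   For illustration, a minimal sketch of the explicit-tuple lookup. `TupleKeyLookup`, `sizeOf`, and the map contents are hypothetical; only the `(shuffleId, reduceId)` key shape comes from the diff.

   ```scala
   object TupleKeyLookup {
     // Hypothetical stand-in for the sizeMap built in sendFetchMergedStatusRequest.
     val sizeMap: Map[(Int, Int), Long] = Map((0, 3) -> 1024L, (0, 4) -> 2048L)

     // The inner parentheses build the tuple key explicitly, so the call does not
     // depend on the compiler adapting a two-argument list into a tuple.
     def sizeOf(shuffleId: Int, reduceId: Int): Long = sizeMap((shuffleId, reduceId))
   }
   ```

   The explicit form reads the same regardless of how a given compiler treats argument adaptation.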




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870079218


   @otterc Can you fix the conflict please?




[GitHub] [spark] mridulm commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-856069303


   @otterc Given the volume of the PR, does it cleanly separate into an ESS side and a client side?
   If it does, we can merge the former first and then the latter.
   
   If not, let us keep it as is.




[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660721217



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration

Review comment:
       We would need to do something to find out the number of hostLocalBlocks in the earlier assertions as well.
   






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655050720



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -661,16 +744,29 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+            if (pushBasedFetchHelper.isMergedLocal(address)) {
+              // It is a local merged block chunk
+              assert(blockId.isShuffleChunk)
+              shuffleMetrics.incLocalBlocksFetched(pushBasedFetchHelper.getNumberOfBlocksInChunk(
+                blockId.asInstanceOf[ShuffleBlockChunkId]))

Review comment:
       I have changed this as well. PTAL.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648520952



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       > Hm..is it possible there's only FetchRequest(hasMergedBlocks) at the beginning? In that case, it seems to cause the fetching process to hang.
   
   It will not cause the fetch process to hang if there is just one FetchRequest with merged blocks.
   Consider this example: if there is a FetchRequest for a merged block `ShuffleBlock(0, -1, 0)`,
   - The iterator sends out the request to fetch the metadata for this block in `PushBasedFetchHelper.sendFetchMergedStatusRequest`.
   - The iterator waits for a response in the result queue at `results.take()`.
   - Once it receives a response, which is either `MergedBlocksMetaFetchResult` or `MergedBlocksMetaFailedFetchResult`, it adds more FetchRequests to the fetch queue and sets `result = null`.
   - `fetchUpToMaxBytes()` is always called after processing the response.
   - Since `result = null`, the while loop repeats and waits again for a response in the result queue.
   
   I will also add a UT for this case just to verify this.
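
The control flow described in this comment can be sketched in isolation. This is a minimal, dependency-free illustration, not the Spark implementation: `MetaResult`, `ChunkData`, `FetchLoopSketch`, and the simulated `fetchUpToMaxBytes` are hypothetical stand-ins, and every simulated request is assumed to complete immediately.

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable

// Hypothetical stand-ins for the iterator's result types (not the Spark classes).
sealed trait Result
case class MetaResult(numChunks: Int) extends Result // role of a merged-block meta response
case class ChunkData(chunkId: Int) extends Result    // role of a successful data response

object FetchLoopSketch {
  val results = new LinkedBlockingQueue[Result]()
  val fetchRequests = mutable.Queue[Int]()

  // Simulates sending every queued request; each one "completes" immediately.
  def fetchUpToMaxBytes(): Unit =
    while (fetchRequests.nonEmpty) results.put(ChunkData(fetchRequests.dequeue()))

  // Loops until a data-bearing result arrives, as the comment describes.
  def next(): ChunkData = {
    var result: ChunkData = null
    while (result == null) {
      results.take() match {
        case MetaResult(n) =>
          // Meta response: queue the chunk requests and leave result null to loop again.
          fetchRequests ++= (0 until n)
        case c: ChunkData =>
          result = c
      }
      // Always called after processing a response, so new requests are sent out.
      fetchUpToMaxBytes()
    }
    result
  }
}
```

Starting with only a meta response in the queue, the loop does not block forever, because the follow-up chunk requests are sent before the next `results.take()`.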






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656661352



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,

Review comment:
       We seem to be assuming that `numChunks` == `bitmaps.length` (here, in `PushMergedRemoteMetaFetchResult`, etc.).
   Can they be different?
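
One way to make this assumption explicit is a fail-fast check before indexing into the bitmaps. The sketch below is only an illustration and does not answer whether the two values can actually differ: `ChunkMetaCheck` and `chunkIds` are hypothetical names, and `Set[Int]` stands in for `RoaringBitmap` to keep the example dependency-free.

```scala
// Hypothetical, dependency-free stand-in: Set[Int] plays the role of RoaringBitmap.
object ChunkMetaCheck {
  def chunkIds(numChunks: Int, bitmaps: Array[Set[Int]]): Seq[Int] = {
    // Fail fast instead of hitting an out-of-bounds access when numChunks is larger,
    // or silently ignoring trailing bitmaps when it is smaller.
    require(numChunks == bitmaps.length,
      s"numChunks ($numChunks) does not match the number of bitmaps (${bitmaps.length})")
    (0 until numChunks).toList
  }
}
```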

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified merged local blocks and their sizes.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          pushMergedLocalBlocks.takeWhile {

Review comment:
       I am trying to understand why `takeWhile` is used here while `foreach` is used in the `cachedMergerDirs.isDefined` branch above and in the `Failure` case below.
   What happens to the remaining blocks when `fetchPushMergedLocalBlock` returns `false`? They seem to be getting dropped silently.
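
The behavioural difference behind this question can be shown with plain collections. This is a sketch only: `TakeWhileVsForeach` and its `fetch` function are hypothetical, and `fetch` is assumed to fail for block 2.

```scala
import scala.collection.mutable

object TakeWhileVsForeach {
  // Hypothetical fetch: records the attempt and reports failure for block 2.
  def fetch(block: Int, attempted: mutable.Buffer[Int]): Boolean = {
    attempted += block
    block != 2
  }

  def run(): (List[Int], List[Int]) = {
    val blocks = mutable.LinkedHashSet(1, 2, 3, 4)
    val viaTakeWhile = mutable.ArrayBuffer[Int]()
    val viaForeach = mutable.ArrayBuffer[Int]()
    // takeWhile stops at the first false, so blocks 3 and 4 are never attempted.
    blocks.takeWhile(b => fetch(b, viaTakeWhile))
    // foreach attempts every block regardless of earlier failures.
    blocks.foreach(b => fetch(b, viaForeach))
    (viaTakeWhile.toList, viaForeach.toList)
  }
}
```

With `takeWhile`, whatever should happen to the unvisited blocks has to be handled by the code that made the predicate return `false`; otherwise they are skipped.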

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)

Review comment:
       For a small enough chunk, can't this throw an exception?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +879,85 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the local push-merged meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the local push-merged directories from the ESS. In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, _, bitmaps, localDirs, _) =>
+            // Fetch local push-merged shuffle block data as multiple shuffle chunks
+            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            try {
+              val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+                localDirs)
+              // Since the request for local block meta completed successfully, numBlocksToFetch
+              // is decremented.
+              numBlocksToFetch -= 1
+              // Update total number of blocks to fetch, reflecting the multiple local shuffle
+              // chunks.
+              numBlocksToFetch += bufs.size
+              for (chunkId <- bufs.indices) {
+                val buf = bufs(chunkId)

Review comment:
       ```suggestion
                 bufs.zipWithIndex.foreach { case (buf, chunkId) =>
   ```
   
   Avoid the potentially O(n) indexed access of `bufs(chunkId)` on a `Seq`.
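
A small standalone sketch of the pattern, assuming strings in place of `ManagedBuffer` (`ChunkIndexing` and `label` are hypothetical names). Each element is paired with its index in a single pass, so no per-element `bufs(i)` lookup is needed, which matters when the `Seq` is a linked list.

```scala
import scala.collection.mutable

object ChunkIndexing {
  // Pairs each buffer with its chunk id in one pass, without calling bufs(i).
  def label(bufs: Seq[String]): List[String] = {
    val out = mutable.ListBuffer[String]()
    bufs.zipWithIndex.foreach { case (buf, chunkId) =>
      out += s"chunk $chunkId -> $buf"
    }
    out.toList
  }
}
```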

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
     var curBlocks = Seq.empty[FetchBlockInfo]
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
-      assertPositiveBlockSize(blockId, size)
       curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks, address, isLast = false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, areMergedBlocks) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _) => (false, false)
+          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
       curBlocks = createFetchRequests(curBlocks, address, isLast = true,
-        collectedRemoteRequests)
+        collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+        forMergedMetas = areMergedBlocks)
       curRequestSize = curBlocks.map(_.size).sum

Review comment:
       This is the end of the method, and `curRequestSize` is a local variable, so this assignment has no effect :-)
   

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified merged local blocks and their sizes.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir")

Review comment:
       super nit: remove the `s` prefix; this string has no interpolation.
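
   To illustrate the nit, here is a minimal sketch (the object and value names are made up for this example, not taken from the PR). The `s` prefix only matters when the literal contains `$` placeholders, so dropping it from a literal without placeholders should yield the same string.

   ```scala
   object InterpolatorNit {
     // No `$` placeholders here, so the `s` prefix changes nothing.
     val withPrefix: String =
       s"Asynchronous fetching local push-merged blocks without cached executors dir"
     val withoutPrefix: String =
       "Asynchronous fetching local push-merged blocks without cached executors dir"

     // The prefix is needed only when a value is spliced into the message.
     def withDirs(dirs: Array[String]): String =
       s"Fetching local push-merged blocks with cached executors dir: ${dirs.mkString(", ")}"
   }
   ```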

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified merged local blocks and their sizes.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          pushMergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchPushMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local push-merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")

Review comment:
       Move this log statement to before the iteration over `pushMergedLocalBlocks`?
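
   One possible reading of this suggestion, sketched under assumptions: the elapsed time is logged as soon as the dirs arrive, so it measures only the directory lookup and not the per-block work that follows. `LogBeforeIterating`, `log` and `fetched` are hypothetical stand-ins for the helper, `logDebug` and `fetchPushMergedLocalBlock`; none of them exist in the PR.

   ```scala
   import java.util.concurrent.TimeUnit
   import scala.collection.mutable

   object LogBeforeIterating {
     // Hypothetical stand-ins: `log` collects messages, `fetched` records processed blocks.
     val log = mutable.ArrayBuffer.empty[String]
     val fetched = mutable.ArrayBuffer.empty[String]

     def onDirsFetched(startTimeNs: Long, dirs: Array[String], blocks: Seq[String]): Unit = {
       // Log the lookup latency first, so it excludes the per-block work below.
       val elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)
       log += s"Fetched local dirs ${dirs.mkString(", ")} in $elapsedMs ms"
       // Then iterate over the blocks; this stands in for the real fetch call.
       blocks.foreach { blockId => fetched += blockId }
     }
   }
   ```

   With this ordering the message is emitted exactly once per callback, regardless of how many blocks are processed afterwards.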

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified merged local blocks and their sizes.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)

Review comment:
       Add a `get` method instead of creating a copy of `executorIdToLocalDirsCache` ?
   Something like this in `HostLocalDirManager`:
   ```
     private[spark] def getCachedHostLocalDirs(executorId: String): Option[Array[String]] =
       executorIdToLocalDirsCache.synchronized {
         Option(executorIdToLocalDirsCache.getIfPresent(executorId))
       }
   ```
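
   For comparison, a self-contained sketch of the two access patterns. Everything here is an assumption made for illustration: `DirCacheSketch`, `copyOfCachedDirs` and `cachedDirsFor` are hypothetical names, and a `ConcurrentHashMap` stands in for the real cache type used by `HostLocalDirManager`.

   ```scala
   import java.util.concurrent.ConcurrentHashMap

   // Hypothetical stand-in for HostLocalDirManager; not the real class.
   class DirCacheSketch {
     private val executorIdToLocalDirsCache = new ConcurrentHashMap[String, Array[String]]()

     def put(executorId: String, dirs: Array[String]): Unit =
       executorIdToLocalDirsCache.put(executorId, dirs)

     // Copy-based accessor: builds a full Map on every call, even if one key is needed.
     def copyOfCachedDirs: Map[String, Array[String]] = {
       var copy = Map.empty[String, Array[String]]
       val it = executorIdToLocalDirsCache.entrySet().iterator()
       while (it.hasNext) {
         val e = it.next()
         copy += (e.getKey -> e.getValue)
       }
       copy
     }

     // Keyed accessor: a single lookup; Option(...) maps a null result to None.
     def cachedDirsFor(executorId: String): Option[Array[String]] =
       Option(executorIdToLocalDirsCache.get(executorId))
   }
   ```

   The keyed accessor avoids allocating a new map per call, which is the saving this comment seems to be after.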

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -661,18 +745,21 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
-              shuffleMetrics.incLocalBlocksFetched(1)
-              shuffleMetrics.incLocalBytesRead(buf.size)
-            } else {
-              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
-              shuffleMetrics.incRemoteBytesRead(buf.size)
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
-                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
-              }
-              shuffleMetrics.incRemoteBlocksFetched(1)
-              bytesInFlight -= size
-            }
+           if (hostLocalBlocks.contains(blockId -> mapIndex) ||
+             pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
+             // It is a host local block or a local shuffle chunk
+             shuffleMetrics.incLocalBlocksFetched(1)
+             shuffleMetrics.incLocalBytesRead(buf.size)
+           } else {
+             // Could be a remote shuffle chunk or remote block
+             numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+             shuffleMetrics.incRemoteBytesRead(buf.size)
+             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+               shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+             }
+             shuffleMetrics.incRemoteBlocksFetched(1)
+             bytesInFlight -= size
+           }

Review comment:
       Most of this diff is due to a change in indentation; revert to 2-space indentation?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-869453190


   +CC @Ngone51 Any further comments/suggestions ?
   I was planning to merge this tomorrow, thanks !




[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656298069



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       @mridulm In the latest changes, the part of `fetchMergedLocalBlock` that modified the map now runs only on the task thread, when it processes `PushMergedLocalMetaFetchResult`. Since the map is now modified only by the task thread, I have changed it back to a regular `HashMap`. PTAL
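
The thread-confinement idea in this comment can be sketched as follows. This is only an illustration in Java with hypothetical names (`ConfinedChunkMap`, `MetaResult`, `onMetaFetched`, `drain`), not the PR's code: callback threads only enqueue results, and a single consumer (the task thread in the comment above) drains the queue and is the only one to mutate a plain `HashMap`.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/** Hypothetical illustration of confining map updates to one consumer thread. */
final class ConfinedChunkMap {
  /** Hypothetical stand-in for a meta result handed over by a callback thread. */
  static final class MetaResult {
    final String chunkId;
    final int numMapBlocks;

    MetaResult(String chunkId, int numMapBlocks) {
      this.chunkId = chunkId;
      this.numMapBlocks = numMapBlocks;
    }
  }

  // Thread-safe hand-off point: any thread may enqueue here.
  private final BlockingQueue<MetaResult> results = new LinkedBlockingQueue<>();
  // Plain HashMap: only touched by the thread that calls drain() and lookup().
  private final Map<String, Integer> chunksMeta = new HashMap<>();

  /** Called from callback threads; never touches the map. */
  void onMetaFetched(String chunkId, int numMapBlocks) {
    results.add(new MetaResult(chunkId, numMapBlocks));
  }

  /** Called only from the consumer thread; the sole place the map is mutated. */
  int drain() {
    int processed = 0;
    MetaResult r;
    while ((r = results.poll()) != null) {
      chunksMeta.put(r.chunkId, r.numMapBlocks);
      processed++;
    }
    return processed;
  }

  /** Called only from the consumer thread. */
  Integer lookup(String chunkId) {
    return chunksMeta.get(chunkId);
  }
}
```

Because the queue is the only shared structure, the map itself needs no synchronization as long as `drain` and `lookup` stay on one thread.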




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649449168



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)

Review comment:
       Thanks for pointing this out.
   This method is always called for a merged shuffle block or chunk, so `isMergedShuffleBlockAddress(address)` is always `true`. It is called from 2 places in the iterator, where we just want to know whether the merged block/chunk is remote. I am going to change the method's name and implementation so that it does what it is intended for.
   
   For fetching merged local blocks, we follow the same approach that was introduced in SPARK-27651.
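
To make the point above concrete, here is a rough Java sketch with hypothetical names (`Address`, `MergedAddressChecks`, `isRemoteMerged`) and a placeholder identifier value. It assumes an address is reduced to an executor id and a host, which is simpler than the real `BlockManagerId`. For a merged address, the two-branch check from the diff collapses to a plain "is the host different" test.

```java
/** Hypothetical, simplified stand-in for a block manager address. */
final class Address {
  final String executorId;
  final String host;

  Address(String executorId, String host) {
    this.executorId = executorId;
    this.host = host;
  }
}

final class MergedAddressChecks {
  // Placeholder value for illustration; the real constant is SHUFFLE_MERGER_IDENTIFIER.
  static final String MERGER_ID = "merger";

  static boolean isMerged(Address a) {
    return MERGER_ID.equals(a.executorId);
  }

  /** Same shape as the check in the diff; the second branch is dead for merged addresses. */
  static boolean isNotExecutorOrMergedLocal(Address a, Address self) {
    boolean sameExecutor = a.executorId.equals(self.executorId) && a.host.equals(self.host);
    return (isMerged(a) && !a.host.equals(self.host)) || (!isMerged(a) && !sameExecutor);
  }

  /** Narrowed check: true only for a merged address on another host. */
  static boolean isRemoteMerged(Address a, Address self) {
    return isMerged(a) && !a.host.equals(self.host);
  }
}
```

When every caller passes a merged address, both methods return the same value, so the narrower name and body state the intent more directly.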






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648845832



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       Many methods in `PushBasedFetchHelper` also need access to the iterator instance. The helper has to work with the iterator in order to:
   1. add results to the iterator's `result` queue when it receives the meta response.
   2. update the number of blocks to fetch.
   3. fetch fallback blocks when there is a fallback, which in turn removes some pending blocks from `fetchRequests`.
   
   It also needs access to the `shuffleClient`, `blockManager`, and `mapOutputTracker`. Most of the methods in this class access one or more of these instances.
   
   Also, each helper instance contains `chunksMetaMap`. To make `PushBasedFetchHelper` a trait, this map would have to move to `ShuffleBlockFetcherIterator` and then be passed to each helper method that needs it.
   
   IMO, it seems better to create one instance of `PushBasedFetchHelper` per iterator instance. Otherwise, all the methods of `PushBasedFetchHelper` would need many more arguments.
   
   I find this class similar to the existing `BufferReleasingInputStream` in the iterator.
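
The design argument above can be illustrated with a small Java sketch. All names here (`OwnerIterator`, `FetchHelper`, `onMetaResponse`) are hypothetical and the counter update is a simplification, not the PR's logic. The helper keeps a reference to its owner and owns its own map, so its methods take only the arguments specific to each call.

```java
import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Map;
import java.util.Queue;

/** Hypothetical owner with the state the helper needs to reach. */
final class OwnerIterator {
  final Queue<String> results = new ArrayDeque<>();
  int numBlocksToFetch = 0;
  // One helper per owner instance.
  final FetchHelper helper = new FetchHelper(this);
}

/** Hypothetical helper that holds its owner and its own per-instance state. */
final class FetchHelper {
  private final OwnerIterator owner;
  private final Map<String, Integer> chunksMeta = new HashMap<>();

  FetchHelper(OwnerIterator owner) {
    this.owner = owner;
  }

  /** Only call-specific arguments are needed; shared state is reached via fields. */
  void onMetaResponse(String blockId, int numChunks) {
    chunksMeta.put(blockId, numChunks);
    owner.results.add(blockId);
    owner.numBlocksToFetch += numChunks;
  }

  int numChunks(String blockId) {
    return chunksMeta.getOrDefault(blockId, 0);
  }
}
```

A stateless alternative would need the queue, the counter holder, and the map as extra parameters on every such method, which is the growth in argument lists the comment warns about.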






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660698644



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1403,67 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an un-successful fetch of either of these:
+   * 1) Remote shuffle chunk.
+   * 2) Local push-merged block.
+   *
+   * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks.
+   *
+   * @param blockId block id
+   * @param address BlockManager that the push-merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class FallbackOnPushMergedFailureResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param blockSize size of each push-merged block.
+   * @param bitmaps bitmaps for every chunk.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a local push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param bitmaps bitmaps for every chunk.
+   * @param localDirs local directories where the push-merged shuffle files are stored
+   */
+  private[storage] case class PushMergedLocalMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      bitmaps: Array[RoaringBitmap],
+      localDirs: Array[String],
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult

Review comment:
       I didn't realize that `FetchResult` had changed. Earlier it had a `blockId` field, which is why I was using `DUMMY_SHUFFLE_BLOCK_ID`. Will change it. Thanks for pointing it out.
   ```
   private[storage] sealed trait FetchResult {
     val blockId: BlockId
     val address: BlockManagerId
   }
   ```
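
As a rough illustration of why the dummy id was needed, the Java sketch below contrasts two shapes of a result type. It assumes, without confirmation from this thread, that the changed `FetchResult` no longer requires a block id; all names (`FetchResultWithId`, `FetchResultMarker`, `MetaResultOld`, `MetaResultNew`) are hypothetical.

```java
/** Shape matching the quoted trait: every result must expose a block id. */
interface FetchResultWithId {
  String blockId();
}

/** Assumed later shape: a bare marker type with no required members. */
interface FetchResultMarker {}

/** A meta result has no natural block id, so the old shape forces a dummy value. */
final class MetaResultOld implements FetchResultWithId {
  static final String DUMMY_BLOCK_ID = "dummy";
  final int shuffleId;
  final int reduceId;

  MetaResultOld(int shuffleId, int reduceId) {
    this.shuffleId = shuffleId;
    this.reduceId = reduceId;
  }

  @Override
  public String blockId() {
    return DUMMY_BLOCK_ID;
  }
}

/** With a marker base type, the meta result carries only the fields it really has. */
final class MetaResultNew implements FetchResultMarker {
  final int shuffleId;
  final int reduceId;

  MetaResultNew(int shuffleId, int reduceId) {
    this.shuffleId = shuffleId;
    this.reduceId = reduceId;
  }
}
```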






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645901943



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);

Review comment:
       done

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);

Review comment:
       done
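
The grouping step in the diff above can be sketched in isolation. This is a simplified Java illustration with hypothetical names (`BlockIdGrouping`, `Grouped`, `group`), not the fetcher's code. It assumes ids of the form `prefix_shuffleId_primaryId_secondaryId`, uses `Long` keys for both the block and chunk cases, and ignores batch fetch. The key property is that the rebuilt id order follows the grouped order, so returned data can be matched to ids by position.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class BlockIdGrouping {
  /** Ids in the order they will be requested, plus secondary ids per primary id. */
  static final class Grouped {
    final List<String> orderedBlockIds = new ArrayList<>();
    final Map<Long, List<Integer>> secondaryIdsByPrimary = new LinkedHashMap<>();
  }

  static Grouped group(String[] blockIds) {
    Grouped g = new Grouped();
    // LinkedHashMap keeps primary ids in first-seen order.
    Map<Long, List<String>> idsByPrimary = new LinkedHashMap<>();
    int shuffleId = Integer.parseInt(blockIds[0].split("_")[1]);
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      if (Integer.parseInt(parts[1]) != shuffleId) {
        throw new IllegalArgumentException(
            "Expected shuffleId=" + shuffleId + ", got:" + blockId);
      }
      long primaryId = Long.parseLong(parts[2]);
      idsByPrimary.computeIfAbsent(primaryId, k -> new ArrayList<>()).add(blockId);
      g.secondaryIdsByPrimary
          .computeIfAbsent(primaryId, k -> new ArrayList<>())
          .add(Integer.parseInt(parts[3]));
    }
    // Rebuild the id list group by group so it matches the request order.
    for (List<String> ids : idsByPrimary.values()) {
      g.orderedBlockIds.addAll(ids);
    }
    return g;
  }
}
```

For example, the input `shuffle_0_5_1`, `shuffle_0_7_0`, `shuffle_0_5_2` is regrouped so that both ids with primary id 5 come first, followed by the one with primary id 7.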






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649505327



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          val pushMergedBlockInfos = blockInfos.map(
+            info => FetchBlockInfo(info._1, info._2, info._3))
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          val size = pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $size")
+          mergedLocalBlockBytes += size
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+      s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    if (hostLocalBlocksCurrentIteration.nonEmpty) {
+      this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
+    }
     collectedRemoteRequests
   }
 
   private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
-      address: BlockManagerId): FetchRequest = {
+      address: BlockManagerId,
+      forMergedMetas: Boolean = false): FetchRequest = {
     logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address "
       + s"with ${blocks.size} blocks")
-    FetchRequest(address, blocks)
+    FetchRequest(address, blocks, forMergedMetas)
   }
 
   private def createFetchRequests(
       curBlocks: Seq[FetchBlockInfo],
       address: BlockManagerId,
       isLast: Boolean,
-      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo] = {
-    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch)
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      forMergedMetas: Boolean = false): Seq[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch)

Review comment:
       > Is mergeContinuousShuffleBlockIdsIfNeeded relevant for merged blocks/chunks ?
   
   No, it is not relevant for merged blocks/chunks. For both merged blocks and chunks, I am passing `enableBatchFetch = false`, so `mergeContinuousShuffleBlockIdsIfNeeded` returns the passed-in blocks unchanged.
   
   I am not seeing any side effects of reusing this method for merged blocks/chunks. IIUC, this method enforces the limit of `maxBlocksInFlightPerAddress` for a `FetchRequest` and is the one that modifies `numBlocksToFetch` for remote requests.
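
To illustrate the pass-through behaviour described above, here is a minimal sketch. It assumes a simplified `BlockInfo` type and a hypothetical `mergeContinuousIfNeeded` function; neither is Spark's real `FetchBlockInfo` or `mergeContinuousShuffleBlockIdsIfNeeded`, and the merge rule is reduced to "same map id, adjacent reduce-id ranges".

```scala
object MergeSketch {
  // Simplified stand-in for Spark's FetchBlockInfo (assumption, not the real class).
  // A block covers the reduce-id range [startReduceId, endReduceId).
  final case class BlockInfo(mapId: Long, startReduceId: Int, endReduceId: Int, size: Long)

  // Hypothetical simplification: with batch fetch disabled the input is returned
  // untouched, which is the path taken for merged blocks and chunks.
  def mergeContinuousIfNeeded(
      blocks: Seq[BlockInfo],
      enableBatchFetch: Boolean): Seq[BlockInfo] = {
    if (!enableBatchFetch) {
      blocks
    } else {
      // Fold adjacent blocks of the same map output into one batch entry.
      blocks.foldLeft(Vector.empty[BlockInfo]) { (acc, b) =>
        acc.lastOption match {
          case Some(prev) if prev.mapId == b.mapId && prev.endReduceId == b.startReduceId =>
            acc.init :+ prev.copy(endReduceId = b.endReduceId, size = prev.size + b.size)
          case _ =>
            acc :+ b
        }
      }
    }
  }
}
```

With `enableBatchFetch = false` the caller gets back exactly the sequence it passed in, so reusing the same code path for merged blocks/chunks should not change which blocks end up in a request.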




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648845832



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       A lot of methods in `PushBasedFetchHelper` also need access to the iterator instance. The helper needs to work with the iterator to be able to:
   1. add results to the iterator's `results` queue when it receives the meta response.
   2. update the number of blocks to fetch.
   3. fetch fallback blocks when there is a fallback, which in turn removes some pending blocks from `fetchRequests`.
   
   It also needs access to the `shuffleClient`, `blockManager`, and `mapOutputTracker`. Most of the methods in this class access one or more of these instances.
   
   Also, each instance of the helper contains `chunksMetaMap`. To make `PushBasedFetchHelper` a trait, this map would have to move to `ShuffleBlockFetcherIterator` and then be passed to each method in the helper that needs it.
   
   IMO, it seems better to create an instance of `PushBasedFetchHelper` per iterator instance. Otherwise, all the methods of `PushBasedFetchHelper` will have way more arguments.
   
   I find this class similar to the existing `BufferReleasingInputStream` in the iterator.
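
A rough sketch of the "one helper per iterator" shape argued for above. `FetcherIterator`, `FetchHelper`, and `handleMeta` are hypothetical, heavily reduced stand-ins for illustration only; they are not the classes in this PR.

```scala
import scala.collection.mutable

object HelperPerIteratorSketch {
  // Hypothetical stand-in for ShuffleBlockFetcherIterator: it owns the results
  // queue and the block counter that the helper has to update.
  final class FetcherIterator {
    val results = mutable.Queue.empty[String]
    var numBlocksToFetch = 0
    // One helper per iterator; the helper keeps a reference to its owner.
    val helper = new FetchHelper(this)
    def onMeta(reduceId: Int, numChunks: Int): Unit = helper.handleMeta(reduceId, numChunks)
  }

  // Hypothetical stand-in for PushBasedFetchHelper: per-instance chunk state
  // lives here, so it does not have to be threaded through every method call.
  final class FetchHelper(iterator: FetcherIterator) {
    private val chunksMeta = mutable.HashMap.empty[(Int, Int), String]

    def handleMeta(reduceId: Int, numChunks: Int): Unit = {
      for (chunk <- 0 until numChunks) {
        chunksMeta((reduceId, chunk)) = s"bitmap-$reduceId-$chunk"
      }
      iterator.numBlocksToFetch += numChunks
      iterator.results.enqueue(s"meta($reduceId, $numChunks)")
    }

    def trackedChunks: Int = chunksMeta.size
  }
}
```

The trade-off is a tighter coupling between the helper and the iterator in exchange for shorter method signatures.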

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       I took another look at it, and I think making it a `ConcurrentMap` and swapping the order of those lines fixes any issues related to multiple threads. PTAL.
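
A minimal sketch of that ordering, under assumptions: `ChunkId` stands in for `ShuffleBlockChunkId`, a `Set[Int]` stands in for the `RoaringBitmap`, and "those lines" is taken to mean the enqueue of the result and the `chunksMetaMap.put`. The names `publishChunk` and `consumeNext` are hypothetical.

```scala
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

object ChunkMetaOrderingSketch {
  // Stand-in for ShuffleBlockChunkId (assumption).
  final case class ChunkId(shuffleId: Int, reduceId: Int, chunkId: Int)

  // Thread-safe map, so a producer thread and the task thread can both touch it.
  val chunksMetaMap = new ConcurrentHashMap[ChunkId, Set[Int]]()
  val results = new LinkedBlockingQueue[ChunkId]()

  // Producer side: register the metadata first, then publish the result.
  // A consumer that sees the result is then guaranteed to find the metadata.
  def publishChunk(id: ChunkId, mapIndices: Set[Int]): Unit = {
    chunksMetaMap.put(id, mapIndices)
    results.put(id)
  }

  // Consumer side: take a result and look up (and drop) its metadata.
  def consumeNext(): (ChunkId, Int) = {
    val id = results.take()
    val meta = chunksMetaMap.remove(id)
    (id, meta.size)
  }
}
```

If the two statements in `publishChunk` were in the opposite order, the consumer could take the result before the map entry exists and `remove` would return `null`.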

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)

Review comment:
       Resolving this one.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r657585363



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,

Review comment:
       If we are asserting on `bitmaps.size() == numChunks`, why are we passing `numChunks` around?
   I am fine with keeping `numChunks` as part of the protocol for forward compatibility, but can the rest of the code, as it stands today, rely on `bitmaps.size()` instead? (With a ser/deser check in `PushMergedLocalMetaFetchResult` to enforce this requirement for now.)
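
One possible shape of that suggestion, as a sketch only: `createChunkInfos` is a hypothetical reduction of `createChunkBlockInfosFromMetaResponse` that drops the `numChunks` parameter and derives the count from the bitmaps. `ChunkId` stands in for `ShuffleBlockChunkId` and `Set[Int]` for `RoaringBitmap`.

```scala
object ChunkInfoSketch {
  // Stand-in for ShuffleBlockChunkId (assumption).
  final case class ChunkId(shuffleId: Int, reduceId: Int, chunkId: Int)

  // The number of chunks is taken from bitmaps.size, so callers cannot pass
  // a numChunks value that disagrees with the bitmaps they supply.
  def createChunkInfos(
      shuffleId: Int,
      reduceId: Int,
      blockSize: Long,
      bitmaps: Seq[Set[Int]]): Seq[(ChunkId, Long)] = {
    require(bitmaps.nonEmpty, "a push-merged block must have at least one chunk")
    // Same approximation as in the diff: block size split evenly across chunks.
    val approxChunkSize = blockSize / bitmaps.size
    bitmaps.indices.map(i => (ChunkId(shuffleId, reduceId, i), approxChunkSize))
  }
}
```

The wire message could still carry `numChunks`; the consistency check would then live at the point where the message is deserialized.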






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646807905



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       This is invoked by the client when it creates a fetch request, and a fetch request covers multiple blocks. Since this runs once per request rather than once per block, it doesn't seem frequent enough to matter, so I think I will leave it this way.
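
   To illustrate the "once per request" point, here is a minimal sketch of the grouping step, assuming the four-part `shuffle_<shuffleId>_<primaryId>_<secondaryId>` ID layout visible in the diff. The class and method names (`BlockIdGrouping`, `groupByPrimaryId`, `primaryIds`) are hypothetical and not part of the PR.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical illustration only; this is not the PR's implementation.
final class BlockIdGrouping {

  /**
   * Groups block IDs by their third field (mapId for regular blocks),
   * preserving first-seen order the way a LinkedHashMap does.
   */
  static Map<Long, List<Integer>> groupByPrimaryId(String[] blockIds) {
    Map<Long, List<Integer>> grouped = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      long primaryId = Long.parseLong(parts[2]);
      int secondaryId = Integer.parseInt(parts[3]);
      grouped.computeIfAbsent(primaryId, k -> new ArrayList<>()).add(secondaryId);
    }
    return grouped;
  }

  /** Copies the keys into a primitive array; this happens once per fetch request. */
  static long[] primaryIds(Map<Long, List<Integer>> grouped) {
    long[] ids = new long[grouped.size()];
    int i = 0;
    for (long id : grouped.keySet()) {
      ids[i++] = id;
    }
    return ids;
  }

  public static void main(String[] args) {
    Map<Long, List<Integer>> grouped = groupByPrimaryId(
        new String[] {"shuffle_0_5_1", "shuffle_0_7_1", "shuffle_0_5_2"});
    System.out.println(Arrays.toString(primaryIds(grouped)) + " " + grouped);
  }
}
```

   The key conversion is linear in the number of distinct primary IDs in one request, which is the cost being discussed here.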




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646728233



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       If we don't expect it to be null, make it a `Preconditions.checkNotNull` and remove the null check?
   Not sure whether this is an artifact of some earlier iteration of the code.
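
   A rough sketch of what I mean, using `java.util.Objects.requireNonNull` as a JDK-only stand-in for Guava's `Preconditions.checkNotNull`. The `SizedBuffer` interface and `sizeToMark` method are hypothetical placeholders for `ManagedBuffer` and the metrics call, just to keep the example self-contained.

```java
import java.util.Objects;

// Hypothetical illustration only; names are not from the PR.
final class NonNullBlockSize {

  /** Minimal stand-in for a buffer type that reports its size. */
  interface SizedBuffer {
    long size();
  }

  /**
   * Fails fast if the block is null, so the caller no longer needs a
   * `block != null ? block.size() : 0` fallback.
   */
  static long sizeToMark(SizedBuffer block) {
    Objects.requireNonNull(block, "merged block data must not be null");
    return block.size();
  }

  public static void main(String[] args) {
    System.out.println(sizeToMark(() -> 42L));
  }
}
```

   With the precondition in place, a null block surfaces as an exception at the point of retrieval instead of being silently recorded as zero bytes.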






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r657585363



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,

Review comment:
       If we are asserting on `bitmaps.size() == numChunks`, why are we passing `numChunks` around?
   I am fine with keeping `numChunks` as part of the protocol for forward compatibility, but the rest of the code, as it stands today, could rely on the size of `bitmaps` instead.
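
   For example, the approximate chunk size could be derived from the bitmap array alone. This is only a sketch: `ChunkSizes` and `approxChunkSize` are hypothetical names, and `Object[]` stands in for the `RoaringBitmap` array so the snippet has no external dependencies.

```java
// Hypothetical illustration only; not the PR's code.
final class ChunkSizes {

  /**
   * Approximate per-chunk size, taking the chunk count from the bitmap
   * array instead of a separately passed numChunks argument.
   */
  static long approxChunkSize(long blockSize, Object[] bitmaps) {
    if (bitmaps.length == 0) {
      throw new IllegalArgumentException("Expected at least one chunk bitmap");
    }
    return blockSize / bitmaps.length;
  }

  public static void main(String[] args) {
    System.out.println(approxChunkSize(100L, new Object[4]));
  }
}
```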






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646727343



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
+        // For regular blocks this is mapId. For chunks this is reduceId.
         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks this is reduceId. For chunks this is chunkId.
         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
       Yeah, that sounds good. Just add a note on the variable explaining what the primary and secondary IDs are in each of the two supported cases.
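
   Something along these lines, as a sketch only. The class name, the `parse` method, and passing the expected prefix as a parameter are assumptions made for illustration; the ID layout follows the four-part format checked in the diff.

```java
import java.util.Arrays;

// Hypothetical illustration only; not the PR's code.
final class PrimarySecondaryIds {

  /**
   * Flattens block IDs into [primary0, secondary0, primary1, secondary1, ...].
   * Regular shuffle blocks: primary = mapId,    secondary = reduceId.
   * Merged shuffle chunks:  primary = reduceId, secondary = chunkId.
   */
  static int[] parse(String[] blockIds, String expectedPrefix) {
    int[] primaryAndSecondaryIds = new int[2 * blockIds.length];
    for (int i = 0; i < blockIds.length; i++) {
      String[] parts = blockIds[i].split("_");
      if (parts.length != 4 || !parts[0].equals(expectedPrefix)) {
        throw new IllegalArgumentException(
            "Unexpected shuffle block id format: " + blockIds[i]);
      }
      primaryAndSecondaryIds[2 * i] = Integer.parseInt(parts[2]);
      primaryAndSecondaryIds[2 * i + 1] = Integer.parseInt(parts[3]);
    }
    return primaryAndSecondaryIds;
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(
        parse(new String[] {"shuffle_0_3_9", "shuffle_0_4_9"}, "shuffle")));
  }
}
```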






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648845832



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       The problem is that `PushBasedFetchHelper` also needs access to the iterator instance. It needs to work with the iterator to be able to:
   1. add results to the iterator's `result` queue.
   2. update the number of blocks to fetch.
   3. fetch fallback blocks when there is a fallback, which in turn removes some pending blocks from `fetchRequests`.
   This is why it is a helper class, similar to the existing `BufferReleasingInputStream` and `ShuffleFetchCompletionListener`.
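
   The shape of that coupling can be sketched roughly as below. This is a simplified illustration, not the actual classes: `Owner` and `Helper` are hypothetical stand-ins for the iterator and the helper, the results are plain strings, and the block-count arithmetic is illustrative only.

```java
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical illustration only; not the PR's code.
final class OwnerAndHelper {

  /** Stand-in for the iterator: owns the results queue and the block counter. */
  static final class Owner {
    private final LinkedBlockingQueue<String> results = new LinkedBlockingQueue<>();
    private int numBlocksToFetch;

    Owner(int numBlocksToFetch) {
      this.numBlocksToFetch = numBlocksToFetch;
    }

    void addToResultsQueue(String result) {
      results.add(result);
    }

    void adjustBlocksToFetch(int delta) {
      numBlocksToFetch += delta;
    }

    int numBlocksToFetch() {
      return numBlocksToFetch;
    }

    int pendingResults() {
      return results.size();
    }
  }

  /** Stand-in for the helper: holds a reference back to its owner. */
  static final class Helper {
    private final Owner owner;

    Helper(Owner owner) {
      this.owner = owner;
    }

    /** Illustrative only: one pending block is replaced by numChunks chunks. */
    void onMetaReceived(int numChunks) {
      owner.adjustBlocksToFetch(numChunks - 1);
      owner.addToResultsQueue("meta:" + numChunks);
    }
  }

  public static void main(String[] args) {
    Owner owner = new Owner(1);
    new Helper(owner).onMetaReceived(3);
    System.out.println(owner.numBlocksToFetch() + " " + owner.pendingResults());
  }
}
```

   Because the helper mutates state that the owner holds, it cannot be a standalone utility without the owner exposing those operations to it.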






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645900747



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {

Review comment:
       done






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656791069



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified merged local blocks and their sizes.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          pushMergedLocalBlocks.takeWhile {

Review comment:
       I have added a UT to catch this.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656731887



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
     var curBlocks = Seq.empty[FetchBlockInfo]
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
-      assertPositiveBlockSize(blockId, size)
       curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks, address, isLast = false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, areMergedBlocks) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _) => (false, false)
+          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
       curBlocks = createFetchRequests(curBlocks, address, isLast = true,
-        collectedRemoteRequests)
+        collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+        forMergedMetas = areMergedBlocks)
       curRequestSize = curBlocks.map(_.size).sum

Review comment:
       Oh yes :) I will remove this line.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646733707



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal to chunkIds.size().

Review comment:
       In general, given protocol evolution issues, I am fine with sending the length as part of the serde.
   I was trying to understand if this assumption is very strict or whether things can evolve in future such that this assumption is broken.
   
   Will resolve the conversation, thanks for clarifying.
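
   To make the trade-off above concrete, here is a minimal sketch of a length-prefixed encoding for a pair of parallel arrays. It is only an illustration: it uses `java.nio.ByteBuffer` rather than Netty's `ByteBuf`, and the names `ChunkRequestCodec`, `encode`, `decode`, and `Decoded` are hypothetical, not the serde in `FetchShuffleBlockChunks`. Because each array carries its own length, the decoder can check the `reduceIds.length == chunkIds.length` assumption explicitly instead of relying on it silently.

```java
import java.nio.ByteBuffer;

// Hypothetical codec for illustration only; this is NOT Spark's wire format.
final class ChunkRequestCodec {
  private ChunkRequestCodec() {}

  /** Holder for the two decoded parallel arrays. */
  static final class Decoded {
    final int[] reduceIds;
    final int[][] chunkIds;

    Decoded(int[] reduceIds, int[][] chunkIds) {
      this.reduceIds = reduceIds;
      this.chunkIds = chunkIds;
    }
  }

  /** Writes each array with its own length prefix. */
  static ByteBuffer encode(int[] reduceIds, int[][] chunkIds) {
    if (reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException("reduceIds and chunkIds must have the same length");
    }
    // 4 bytes per length prefix, 4 bytes per int value.
    int size = 4 + 4 * reduceIds.length + 4;
    for (int[] ids : chunkIds) {
      size += 4 + 4 * ids.length;
    }
    ByteBuffer buf = ByteBuffer.allocate(size);
    buf.putInt(reduceIds.length);
    for (int reduceId : reduceIds) {
      buf.putInt(reduceId);
    }
    buf.putInt(chunkIds.length);
    for (int[] ids : chunkIds) {
      buf.putInt(ids.length);
      for (int id : ids) {
        buf.putInt(id);
      }
    }
    buf.flip();
    return buf;
  }

  /** Reads both arrays back and validates the parallel-array assumption. */
  static Decoded decode(ByteBuffer buf) {
    int[] reduceIds = new int[buf.getInt()];
    for (int i = 0; i < reduceIds.length; i++) {
      reduceIds[i] = buf.getInt();
    }
    int[][] chunkIds = new int[buf.getInt()][];
    for (int i = 0; i < chunkIds.length; i++) {
      chunkIds[i] = new int[buf.getInt()];
      for (int j = 0; j < chunkIds[i].length; j++) {
        chunkIds[i][j] = buf.getInt();
      }
    }
    if (reduceIds.length != chunkIds.length) {
      throw new IllegalStateException("decoded arrays have different lengths");
    }
    return new Decoded(reduceIds, chunkIds);
  }
}
```

   Encoding the second length costs four extra bytes per message, but it keeps the decoder valid if a future protocol revision ever lets the two arrays differ in length.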






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645900543



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);

Review comment:
       Done






[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870079218








[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649479840



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       Same as above. It is always accessed by the task thread.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r657584917



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)

Review comment:
       Thanks for clarifying! This addresses my concern.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649479384



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)

Review comment:
       No, because `chunksMetaMap` is only ever read or modified by the single task thread. It is accessed either during iterator initialization or when `iterator.next()` is called.
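
       The thread-confinement argument above can be sketched as follows: callback threads only hand results to a thread-safe queue, and one owning thread is the only one that ever touches the plain `HashMap`. This is a simplified Java illustration, not the code in `PushBasedFetchHelper`; `ConfinedChunkTracker`, `onMetaReceived`, `drain`, and the use of `String` chunk ids are all hypothetical.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

// Hypothetical illustration of thread confinement; not Spark code.
final class ConfinedChunkTracker {

  /** Immutable message passed from callback threads to the owning thread. */
  static final class MetaResult {
    final String chunkId;
    final int numBlocks;

    MetaResult(String chunkId, int numBlocks) {
      this.chunkId = chunkId;
      this.numBlocks = numBlocks;
    }
  }

  // Thread-safe hand-off point shared by all threads.
  private final BlockingQueue<MetaResult> results = new LinkedBlockingQueue<>();

  // Plain, unsynchronized map: safe only because a single thread uses it.
  private final Map<String, Integer> chunksMeta = new HashMap<>();

  /** May be called from any callback thread; it never touches the map. */
  void onMetaReceived(String chunkId, int numBlocks) {
    results.add(new MetaResult(chunkId, numBlocks));
  }

  /** Must be called only by the owning thread. Returns the number of results applied. */
  int drain() {
    int applied = 0;
    MetaResult result;
    while ((result = results.poll()) != null) {
      chunksMeta.put(result.chunkId, result.numBlocks);
      applied++;
    }
    return applied;
  }

  /** Must be called only by the owning thread. Returns null if the chunk is unknown. */
  Integer removeChunk(String chunkId) {
    return chunksMeta.remove(chunkId);
  }
}
```

       Under this arrangement the map needs no locking or concurrent collection, at the cost of a convention that nothing in the type system enforces: every map access must stay on the owning thread.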






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655050720



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -661,16 +744,29 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+            if (pushBasedFetchHelper.isMergedLocal(address)) {
+              // It is a local merged block chunk
+              assert(blockId.isShuffleChunk)
+              shuffleMetrics.incLocalBlocksFetched(pushBasedFetchHelper.getNumberOfBlocksInChunk(
+                blockId.asInstanceOf[ShuffleBlockChunkId]))

Review comment:
       I have changed this as well. PTAL






[GitHub] [spark] Ngone51 edited a comment on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870235923


   Sorry for the delay. I'll do a review today. BTW, are there any other necessary magnet PRs that have to be merged for the 3.2 cut/release?




[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-854796576


   > Took an initial pass, yet to look at `ShuffleBlockFetcherIterator` or test suites.
   > I am wondering, given the volume, whether we want to split between ESS side and client side. Thoughts ?
   
   Thanks Mridul for reviewing!
   My thinking on splitting this change is that, as it stands, it completely encapsulates the fetch-side changes, so it is easier to understand how the new messages introduced on the client side are handled on the server side. One piece of feedback we got last year was that we broke things up in a way that made them difficult to understand.
   
   That being said, I am still okay with breaking this change into client/server PRs if that makes the review easier for the reviewers.
   cc. @mridulm @Ngone51 @Victsm @tgravescs @attilapiros 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645880908



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal to chunkIds.size().

Review comment:
       This is a strong assumption. For a given `reduceIds[i]`, each chunk ID `chunkIds[i][j]` in `chunkIds[i]` represents a requested shuffle chunk `shuffleChunk_<shuffleId>_<reduceIds[i]>_<chunkIds[i][j]>`.
   This is similar to the existing `FetchShuffleBlocks` message. There are assertions on the server side as well to check this.
   
   > Do we see a future evolution where this can break
   
   Not really. At least I can't think of any that would change this. It just represents all the chunk IDs being requested for a particular reduce ID, which I don't think will change.
   
   > encode and decode do not assume this currently 
   
   That's right. I can avoid writing `chunkIdsLen`. I simply mimicked the code in `FetchShuffleBlocks`, which does the same.
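
To make the `reduceIds`/`chunkIds` pairing described in this comment concrete, here is a minimal standalone sketch. It is not the PR's code: the class `ChunkIdExpansion` and its `expand` method are hypothetical names, and the ID layout is taken from the `shuffleChunk_<shuffleId>_<reduceId>_<chunkId>` form quoted in the comment.

```java
import java.util.ArrayList;
import java.util.List;

// Illustrative sketch only; names are hypothetical and not part of the PR.
final class ChunkIdExpansion {
  private ChunkIdExpansion() {}

  /**
   * Expands the parallel arrays into chunk block IDs of the form
   * shuffleChunk_<shuffleId>_<reduceId>_<chunkId>.
   */
  static List<String> expand(int shuffleId, int[] reduceIds, int[][] chunkIds) {
    // The invariant under discussion: exactly one row of chunk IDs per reduce ID.
    if (reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException(
          "reduceIds.length (" + reduceIds.length + ") != chunkIds.length ("
              + chunkIds.length + ")");
    }
    List<String> blockIds = new ArrayList<>();
    for (int i = 0; i < reduceIds.length; i++) {
      // chunkIds[i][j] always belongs to reduceIds[i].
      for (int chunkId : chunkIds[i]) {
        blockIds.add("shuffleChunk_" + shuffleId + "_" + reduceIds[i] + "_" + chunkId);
      }
    }
    return blockIds;
  }
}
```

For example, `expand(5, new int[] {0, 3}, new int[][] {{0, 1}, {2}})` names two chunks for reduce ID 0 and one chunk for reduce ID 3, while arrays of mismatched length are rejected up front.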






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646736029



##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -124,11 +134,12 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r

Review comment:
       I was thinking more of `-?\\d+` - but that is ok, let us ignore this given every other regex is using the `[0-9]` anyway.
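
A small sketch of what the two spellings accept, using `java.util.regex` (which Scala's `.r` also builds on). The class `ShuffleBlockIdRegex` and the helper `mapIdOf` are hypothetical, and the sample value `-1` for the map ID is only an assumed example of a negative sentinel. By default `\d` in `java.util.regex` matches ASCII digits only, so both patterns should accept the same block names.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Illustrative sketch only; names are hypothetical and not part of the PR.
final class ShuffleBlockIdRegex {
  private ShuffleBlockIdRegex() {}

  // Same pattern text as the updated SHUFFLE regex in the diff above.
  static final Pattern BRACKET =
      Pattern.compile("shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)");

  // The \d spelling mentioned in the review comment.
  static final Pattern BACKSLASH_D =
      Pattern.compile("shuffle_(\\d+)_(-?\\d+)_(\\d+)");

  /** Returns the map ID group as text, or null when the name does not match. */
  static String mapIdOf(Pattern pattern, String blockName) {
    Matcher m = pattern.matcher(blockName);
    return m.matches() ? m.group(2) : null;
  }
}
```

Only the middle group admits a leading minus sign, so a name such as `shuffle_-1_0_3` is still rejected by both patterns.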






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640210756



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +824,66 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                numBlocksProcessed += pushBasedFetchHelper
+                  .initiateFallbackBlockFetchForMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>

Review comment:
       Note to self: Most of this is as before. Have added conditions for shuffleChunks






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656858416



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)

Review comment:
       What I am trying to understand is this: we will be initiating a fallback and discarding the merged data even though nothing is really wrong here, other than the chunk being too small to decompress, right? (That is, when the chunk was split at a boundary that causes decompression to fail.)
   I want to make sure I am not missing something here.
   
   The 2 MB value is configurable, so `maxBytesInFlight` and the chunk size can take different values. Do we enforce any constraint on these to prevent this sort of issue?






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-856524008


   Removed all the changes from here which are now part of SPARK-35671




[GitHub] [spark] asfgit closed pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
asfgit closed pull request #32140:
URL: https://github.com/apache/spark/pull/32140


   




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645902345



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {

Review comment:
       Done
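
For reference, the prefix check in the diff above can be sketched standalone as follows. This is a hedged reconstruction, not the merged code: the diff is cut off before the method body ends, so the `return` statements are assumed, and the constant values are assumptions too (`"shuffle_"` appears in the removed line, while `"shuffleChunk_"` is inferred from the chunk ID format quoted elsewhere in this thread).

```java
// Illustrative sketch only; constant values and return statements are assumed.
final class ShuffleIdPrefixCheck {
  private ShuffleIdPrefixCheck() {}

  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  /** Returns true only if every ID is a shuffle block ID or a shuffle chunk ID. */
  static boolean areShuffleBlocksOrChunks(String[] blockIds) {
    for (String blockId : blockIds) {
      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX)
          && !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }
}
```

Both prefixes have to be tested because, under these assumed values, a chunk ID does not start with `"shuffle_"`: the character after `shuffle` is `C`, not `_`.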






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660221758



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,35 +355,56 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+    // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+    // in order to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._3).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback)
+          .contains(address.executorId)) {

Review comment:
       nit: `mutable.HashSet` -> `Set`
   Also, could you pull this `Set` out of the loop? (You can remove the `fallback` variable and populate the `Set` directly instead.)
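
The suggestion amounts to building an immutable set once, before the loop, instead of allocating a mutable one on every iteration. A rough Java rendering of that idea follows. The reviewed code is Scala, and `HoistedSetSketch`, `countLocalOrFallback` and the plain string executor IDs are hypothetical stand-ins for the `BlockManagerId` handling in the diff.

```java
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Illustrative sketch only; names are hypothetical and not part of the PR.
final class HoistedSetSketch {
  private HoistedSetSketch() {}

  /** Counts the executor IDs that are either the local one or the fallback one. */
  static int countLocalOrFallback(
      List<String> executorIds, String localExecutorId, String fallbackExecutorId) {
    // Built once, outside the loop, and never mutated afterwards.
    Set<String> localOrFallback = Collections.unmodifiableSet(
        new HashSet<>(Arrays.asList(localExecutorId, fallbackExecutorId)));
    int count = 0;
    for (String executorId : executorIds) {
      if (localOrFallback.contains(executorId)) {
        count++;
      }
    }
    return count;
  }
}
```

Populating the set directly from the two IDs also removes the need for a separate `fallback` local, which is the second half of the suggestion.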






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649482176



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {

Review comment:
       > We have possibility of only ShuffleBlockId or ShuffleBlockChunkId in this method right ?
   
   Yes. Will add assertions.
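
A minimal sketch of the kind of assertion discussed here, for readers following the thread. The types below are hypothetical stand-ins rather than Spark's real `BlockId` hierarchy, and the method name is an assumption made for illustration only:

```scala
// Hypothetical stand-ins for Spark's block id types; not the real classes.
sealed trait BlockId
final case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId
final case class ShuffleBlockChunkId(shuffleId: Int, reduceId: Int, chunkId: Int) extends BlockId
final case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId

object FallbackAssertionSketch {
  // Returns (shuffleId, reduceId) for the two block id kinds the fallback path
  // is expected to see, and fails fast for anything else.
  def shuffleAndReduceId(blockId: BlockId): (Int, Int) = {
    assert(
      blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId],
      s"Expected ShuffleBlockId or ShuffleBlockChunkId but got $blockId")
    blockId match {
      case ShuffleBlockId(shuffleId, _, reduceId) => (shuffleId, reduceId)
      case ShuffleBlockChunkId(shuffleId, reduceId, _) => (shuffleId, reduceId)
      // Only reachable when assertions are disabled at compile time.
      case other => throw new IllegalStateException(s"Unexpected block id $other")
    }
  }
}
```

The explicit fallthrough case keeps the method failing fast even if assertions are elided by the compiler.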




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc edited a comment on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-859181282


   This is still dependent on the changes in https://github.com/apache/spark/pull/32140 which has the protocol side of changes.
   
   Correction: This is dependent on #32811 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645872258



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
       I reverted the code. We do want the same behavior, that is, to fail the remaining blocks.
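
The "fail remaining blocks" behavior can be sketched roughly as follows. This is an illustration only: the `Listener` trait and `failRemainingBlocks` signature are simplified assumptions, not the actual `OneForOneBlockFetcher` code, which is written in Java.

```scala
import scala.util.control.NonFatal

object FailRemainingSketch {
  // Hypothetical, simplified listener used only for this sketch.
  trait Listener {
    def onBlockFetchFailure(blockId: String, cause: Throwable): Unit
  }

  // When the stream for the block at `failedIndex` fails, report a failure for
  // that block and for every block after it in the same request.
  def failRemainingBlocks(
      blockIds: Array[String],
      failedIndex: Int,
      cause: Throwable,
      listener: Listener): Unit = {
    blockIds.drop(failedIndex).foreach { blockId =>
      try {
        listener.onBlockFetchFailure(blockId, cause)
      } catch {
        // A misbehaving listener should not stop the remaining notifications.
        case NonFatal(_) => ()
      }
    }
  }
}
```

Blocks before `failedIndex` are left alone, on the assumption that they have already been delivered.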






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640321817



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -720,7 +720,7 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus) {
+  def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus): Unit = {

Review comment:
       Note to reviewers: this method was added as part of push-based shuffle, so this change just fixes the warning below:
   ```
   [warn] /home/runner/work/spark/spark/core/src/main/scala/org/apache/spark/MapOutputTracker.scala:723:79: [deprecation @  | origin= | version=2.13.0] procedure syntax is deprecated: instead, add `: Unit =` to explicitly declare `registerMergeResult`'s return type
   [warn]   def registerMergeResult(shuffleId: Int, reduceId: Int, status: MergeStatus) {
   ```
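
For context, a small self-contained illustration of the same fix outside Spark. The object and method names are made up for this sketch:

```scala
object ProcedureSyntaxSketch {
  private var registered = 0

  // Deprecated procedure syntax (warns on Scala 2.13, dropped in Scala 3):
  //   def register(count: Int) { registered += count }

  // Preferred form: declare the Unit result type and use `=`.
  def register(count: Int): Unit = {
    registered += count
  }

  def total: Int = registered
}
```

Behavior is unchanged; only the declaration style differs.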






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648594657



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       I will work on separating this into its own file. As for making this a trait, do you mean that we define APIs for `PushBasedFetchHelper` so that there could be different implementations? If so, I am not sure it would add much value. Right now, this implementation and the methods in this class are tightly coupled to the implementation of push-based shuffle. The API is not generic.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648845832



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       The problem is that `PushBasedFetchHelper` also needs access to the iterator instance. It needs to work with the iterator to be able to:
   1. add results to the iterator's `results` queue.
   2. update the number of blocks to fetch.
   3. fetch fallback blocks when a fallback is triggered, which in turn removes some pending blocks from `fetchRequests`.
   
   It also needs access to the `shuffleClient`, `blockManager`, and `mapOutputTracker`.
   This is why it is a helper class, similar to the existing `BufferReleasingInputStream` and `ShuffleFetchCompletionListener`.
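
A rough sketch of the coupling described in this comment, with the helper holding a reference to its owning iterator. All class and method names below are simplified assumptions for illustration; the real `ShuffleBlockFetcherIterator` has far more state:

```scala
import java.util.concurrent.LinkedBlockingQueue

object HelperClassSketch {
  sealed trait FetchResult
  final case class Success(blockId: String) extends FetchResult

  // Hypothetical, heavily simplified owner: only the state the helper touches.
  final class FetcherIterator(initialBlocks: Int) {
    private val results = new LinkedBlockingQueue[FetchResult]()
    private var numBlocksToFetch = initialBlocks
    val helper = new PushFetchHelper(this)

    def addToResultsQueue(result: FetchResult): Unit = results.put(result)
    def foundMoreBlocksToFetch(delta: Int): Unit = numBlocksToFetch += delta
    def blocksToFetch: Int = numBlocksToFetch
    def pendingResults: Int = results.size()
  }

  // The helper cannot stand alone: it mutates the owner's queue and counters.
  final class PushFetchHelper(iterator: FetcherIterator) {
    // One merged block expands into `numChunks` results, so the owner's
    // counter grows by numChunks - 1.
    def enqueueChunks(mergedBlockId: String, numChunks: Int): Unit = {
      iterator.foundMoreBlocksToFetch(numChunks - 1)
      (0 until numChunks).foreach { chunkId =>
        iterator.addToResultsQueue(Success(s"${mergedBlockId}_$chunkId"))
      }
    }
  }
}
```

Because every helper method reaches back into the owner, extracting a generic trait would mostly re-expose the iterator's internals.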






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649457152



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>

Review comment:
       One reason to catch `Throwable` is that for any kind of error or exception while reading the meta information, we want to ensure that the fallback is triggered. Even though `meta.readChunkBitmaps()` can only throw `IOException`, I found it safer to catch `Throwable` here so that the fallback is always triggered.
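
The trade-off can be sketched as follows. This is not the PR's code: `handleMeta`, `parseMeta`, and the result types are hypothetical names used only to show how a broad catch turns every parse failure into a fallback signal:

```scala
import scala.collection.mutable.ArrayBuffer

object MetaFallbackSketch {
  sealed trait Result
  final case class MetaFetched(numChunks: Int) extends Result
  case object MetaFailed extends Result

  // `parseMeta` stands in for reading the chunk bitmaps from the meta buffer.
  def handleMeta(parseMeta: () => Int, queue: ArrayBuffer[Result]): Unit = {
    try {
      queue += MetaFetched(parseMeta())
    } catch {
      // Deliberately broad: any failure while parsing enqueues the failure
      // result, so the consumer can fall back to the original blocks.
      case _: Throwable =>
        queue += MetaFailed
    }
  }
}
```

Note that catching `Throwable` also intercepts errors that `scala.util.control.NonFatal` would leave alone, such as `OutOfMemoryError`; that is the price of guaranteeing the fallback.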






[GitHub] [spark] otterc edited a comment on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-854796576


   > Took an initial pass, yet to look at `ShuffleBlockFetcherIterator` or test suites.
   > I am wondering, given the volume, whether we want to split between ESS side and client side. Thoughts ?
   
   Thanks Mridul for reviewing!
   My thought about splitting this change is that this PR completely encapsulates the fetch-side changes, so it is easier to understand how the new messages introduced on the client side are handled on the server side. One piece of feedback we got last year was that we broke things up in a way that made them difficult to understand.
   On the server side, this PR mostly adds the wiring needed in `ExternalBlockHandler` to serve the merged meta/data requests from `RemoteBlockPushResolver`.
   
   That being said, I am still okay with breaking this change into client/server PRs if that makes the review easier for the reviewers.
   cc. @mridulm @Ngone51 @Victsm @tgravescs @attilapiros 




[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660766903



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +878,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the local push-merged meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the local push-merged directories from the ESS. In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) =>
+            // Fetch local push-merged shuffle block data as multiple shuffle chunks
+            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            try {
+              val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+                localDirs)
+              // Since the request for local block meta completed successfully, numBlocksToFetch
+              // is decremented.
+              numBlocksToFetch -= 1
+              // Update total number of blocks to fetch, reflecting the multiple local shuffle
+              // chunks.
+              numBlocksToFetch += bufs.size
+              bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+                  isNetworkReqDone = false))
+              }
+            } catch {
+              case e: Exception =>
+                // If we see an exception with reading local push-merged data, we fallback to

Review comment:
       IIUC, `getLocalMergedBlockData` doesn't actually read the data file. How would an exception be thrown by the data file?
   






[GitHub] [spark] Ngone51 edited a comment on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870235923


   Sorry for the delay. I'll do a review today. BTW, are there any other necessary magnet PRs that have to be merged for the 3.2 release?




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655045134



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it
+   * can be modified by both the task thread and the netty thread.
+   */
+  private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote merged block.
+   */
+  def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = {
+    assert(isMergedShuffleBlockAddress(address))
+    address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap.get(blockId).getCardinality
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps   per chunk bitmap, where each bitmap contains all the mapIds that are merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding merged local blocks.
+   * @param mergedLocalBlocks set of identified merged local blocks.
+   */
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult(
+              blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block. This can be executed by the task thread as well as
+   * the netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.incrementNumBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          FallbackOnMergedFailureFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type:
+   * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]]
+   * 2) [[ShuffleBlockFetcherIterator.FallbackOnMergedFailureFetchResult]]
+   * 3) [[ShuffleBlockFetcherIterator.MergedMetaFailedFetchResult]]
+   *
+   * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that
+   * failed to fetch.
+   * It makes a call to the map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      blockId match {
+        case shuffleBlockId: ShuffleBlockId =>
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+        case _ =>
+          val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+          val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId)
+          assert(chunkBitmap != null)
+          // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+          // fallback not only for that particular remote shuffle block chunk but also for all the
+          // pending block chunks that belong to the same host. The reason for doing so is that it
+          // is very likely that the subsequent requests for merged block chunks from this host will
+          // fail as well. Since push-based shuffle is best effort and we try not to increase the
+          // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the
+          // fetchRequests queue.
+          if (isMergedBlockAddressRemote(address)) {
+            // Fallback for all the pending fetch requests
+            val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+            if (pendingShuffleChunks.nonEmpty) {
+              pendingShuffleChunks.foreach { pendingBlockId =>
+                logInfo(s"Falling back immediately for merged block $pendingBlockId")
+                val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId)
+                assert(bitmapOfPendingChunk != null)
+                chunkBitmap.or(bitmapOfPendingChunk)
+              }
+              // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed
+              blocksProcessed += pendingShuffleChunks.size

Review comment:
       I have made this change as well. `initiateFallbackFetchForPushMergedBlock` is now decrementing `numBlocksToFetch`
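
For readers following the hunk above, here is a minimal sketch of the bitmap union performed during fallback. It is not the PR's code: `ChunkFallbackSketch`, `addChunk`, `mapIndexesToRefetch`, and `trackedChunks` are hypothetical names, chunk IDs are plain strings, and `java.util.BitSet` stands in for `RoaringBitmap` so the example needs no external dependency.

```java
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ChunkFallbackSketch {
  // Hypothetical stand-in for chunksMetaMap: chunk id -> map indexes merged into that chunk.
  private final Map<String, BitSet> chunksMeta = new HashMap<>();

  void addChunk(String chunkId, BitSet mapIndexes) {
    chunksMeta.put(chunkId, mapIndexes);
  }

  /**
   * Removes the failed chunk and every pending chunk from the same host, and returns the
   * union of their bitmaps: the map indexes whose original blocks have to be refetched.
   */
  BitSet mapIndexesToRefetch(String failedChunk, List<String> pendingChunks) {
    BitSet merged = chunksMeta.remove(failedChunk);
    if (merged == null) {
      throw new IllegalStateException("No bitmap tracked for " + failedChunk);
    }
    for (String pending : pendingChunks) {
      BitSet bitmap = chunksMeta.remove(pending);
      if (bitmap == null) {
        throw new IllegalStateException("No bitmap tracked for " + pending);
      }
      merged.or(bitmap);
    }
    return merged;
  }

  int trackedChunks() {
    return chunksMeta.size();
  }

  static BitSet bits(int... indexes) {
    BitSet set = new BitSet();
    for (int index : indexes) {
      set.set(index);
    }
    return set;
  }

  public static void main(String[] args) {
    ChunkFallbackSketch helper = new ChunkFallbackSketch();
    helper.addChunk("shuffleChunk_0_3_0", bits(0, 1));
    helper.addChunk("shuffleChunk_0_3_1", bits(2));
    helper.addChunk("shuffleChunk_0_3_2", bits(3, 4));
    // Chunk 0 failed; chunks 1 and 2 were still pending for the same host.
    BitSet refetch = helper.mapIndexesToRefetch(
        "shuffleChunk_0_3_0", Arrays.asList("shuffleChunk_0_3_1", "shuffleChunk_0_3_2"));
    System.out.println("Map indexes to refetch: " + refetch);
  }
}
```

In this sketch the failed chunk and its pending siblings all leave the tracking map together, which is why the caller has to adjust its block counters by the number of pending chunks removed.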




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645902038



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));

Review comment:
       done
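
As a rough illustration of the grouping in the hunk above, the sketch below groups block IDs by their primary ID (mapId for unmerged blocks, reduceId for merged chunks). It rests on assumptions: the prefix values `shuffle_` and `shuffleChunk_` and the four-part ID layout are assumed here, `BlockIdGroupingSketch` and `groupByPrimaryId` are hypothetical names, and the batch-fetch case with five-part IDs is left out.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class BlockIdGroupingSketch {
  // Assumed prefix values; the real constants live in the shuffle module.
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  /**
   * Groups secondary ids (reduceId for blocks, chunkId for chunks) by primary id
   * (mapId for blocks, reduceId for chunks), preserving first-seen order.
   */
  static Map<Long, List<Integer>> groupByPrimaryId(String[] blockIds) {
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }
    // A single request is expected to hold only blocks or only chunks.
    String prefix = blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)
        ? SHUFFLE_CHUNK_PREFIX : SHUFFLE_BLOCK_PREFIX;
    String shuffleId = null;
    Map<Long, List<Integer>> grouped = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      if (!blockId.startsWith(prefix) || parts.length != 4) {
        throw new IllegalArgumentException("Unexpected block id: " + blockId);
      }
      if (shuffleId == null) {
        shuffleId = parts[1];
      } else if (!shuffleId.equals(parts[1])) {
        throw new IllegalArgumentException(
            "Expected shuffleId=" + shuffleId + ", got:" + blockId);
      }
      long primaryId = Long.parseLong(parts[2]);
      int secondaryId = Integer.parseInt(parts[3]);
      grouped.computeIfAbsent(primaryId, key -> new ArrayList<>()).add(secondaryId);
    }
    return grouped;
  }

  public static void main(String[] args) {
    System.out.println(groupByPrimaryId(
        new String[] {"shuffle_0_5_1", "shuffle_0_5_2", "shuffle_0_7_1"}));
    System.out.println(groupByPrimaryId(
        new String[] {"shuffleChunk_0_3_0", "shuffleChunk_0_3_1"}));
  }
}
```

The sketch parses the primary ID as a `long` in both cases so that the map keys always have one type. The hunk above instead uses `Number` keys holding either a `Long` or an `Integer`, which is safe only as long as a request never mixes the two kinds.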






[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660322354



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1403,67 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an un-successful fetch of either of these:
+   * 1) Remote shuffle chunk.
+   * 2) Local push-merged block.
+   *
+   * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original blocks.
+   *
+   * @param blockId block id
+   * @param address BlockManager that the push-merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class FallbackOnPushMergedFailureResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param blockSize size of each push-merged block.
+   * @param bitmaps bitmaps for every chunk.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a local push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param bitmaps bitmaps for every chunk.
+   * @param localDirs local directories where the push-merged shuffle files are stored.
+   */
+  private[storage] case class PushMergedLocalMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      bitmaps: Array[RoaringBitmap],
+      localDirs: Array[String],
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult

Review comment:
       `blockId` is never used?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -57,6 +59,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
  *                        block, which indicate the index in the map stage.
  *                        Note that zero-sized blocks are already excluded, which happened in
  *                        [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching the original blocks if
+ *                        we fail to fetch shuffle chunks when push based shuffle is enabled.

Review comment:
       nit: indents

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"+ the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +

Review comment:
       maybe, to be consistent with `push-merged-local` in all places?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +878,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the local push-merged meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the local push-merged directories from the ESS. In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) =>
+            // Fetch local push-merged shuffle block data as multiple shuffle chunks
+            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            try {
+              val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+                localDirs)
+              // Since the request for local block meta completed successfully, numBlocksToFetch
+              // is decremented.
+              numBlocksToFetch -= 1
+              // Update total number of blocks to fetch, reflecting the multiple local shuffle
+              // chunks.
+              numBlocksToFetch += bufs.size
+              bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+                  isNetworkReqDone = false))
+              }
+            } catch {
+              case e: Exception =>
+                // If we see an exception with reading local push-merged data, we fallback to

Review comment:
       I think we can only see the exception with reading the push-merged index file rather than data file, right?
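
For context on the counter updates in the hunk above, here is a minimal sketch of the accounting, assuming the iterator reports more results while the processed count is below `numBlocksToFetch`. The class `BlockCountAccountingSketch` and its methods are hypothetical; only the two counter names mirror the diff.

```java
final class BlockCountAccountingSketch {
  private int numBlocksToFetch;
  private int numBlocksProcessed;

  BlockCountAccountingSketch(int initialBlocksToFetch) {
    this.numBlocksToFetch = initialBlocksToFetch;
  }

  /** The single placeholder for a push-merged block is replaced by its chunks. */
  void onLocalMetaFetched(int numChunks) {
    numBlocksToFetch -= 1;         // the meta request itself is complete
    numBlocksToFetch += numChunks; // each chunk is now consumed as its own result
  }

  void onResultProcessed() {
    numBlocksProcessed += 1;
  }

  boolean hasNext() {
    return numBlocksProcessed < numBlocksToFetch;
  }

  int numBlocksToFetch() {
    return numBlocksToFetch;
  }

  public static void main(String[] args) {
    // Two regular blocks plus one push-merged block that turns out to hold four chunks.
    BlockCountAccountingSketch counters = new BlockCountAccountingSketch(3);
    counters.onLocalMetaFetched(4);
    System.out.println("Blocks to fetch after expansion: " + counters.numBlocksToFetch());
  }
}
```

If the chunk buffers could not be produced after the counters were already adjusted, the totals would no longer match the results queued, so the order of the read and the counter updates matters in a design like this.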

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"+ the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration

Review comment:
       Shall we reuse `hostLocalBlocksByExecutor` here? e.g.,
   
   ```scala
   this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
         .flatMap { infos => infos.map(info => (info._1, info._3)) }
   ```
   
   so we can get rid of `hostLocalBlocksCurrentIteration`.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1063,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def incrementNumBlocksToFetch(moreBlocksToFetch: Int): Unit = {

Review comment:
       This looks like it only ever decreases `numBlocksToFetch`. Shall we rename it to `decreaseNumBlocksToFetch`?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +

Review comment:
       `... doesn't equal to the sum of ...` ?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1063,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def incrementNumBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure related to a push-merged block or shuffle chunk.
+   * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+   * fallback.
+   */
+  private[storage] def fallbackFetch(
+      originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val originalHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr,
+      originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(originalRemoteReqs)
+    logInfo(s"Started ${originalRemoteReqs.size} fallback remote requests for push-merged")

Review comment:
       "Started" -> "Created"?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645904465



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;

Review comment:
       I added `_` to the prefixes.
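
To illustrate why the separator matters, here is a minimal sketch. The literal values `shuffle_` and `shuffleChunk_` are assumptions (the diff shows only the constant names `SHUFFLE_BLOCK_PREFIX` and `SHUFFLE_CHUNK_PREFIX`), and `BlockIdPrefixes` is a hypothetical holder object, not part of the PR.

```scala
object BlockIdPrefixes {
  // Assumed values: the diff only shows the constant names, not their contents.
  val ShuffleBlockPrefix: String = "shuffle_"
  val ShuffleChunkPrefix: String = "shuffleChunk_"

  // Mirrors the check in the diff: every ID must carry one of the two prefixes.
  def areShuffleBlocksOrChunks(blockIds: Array[String]): Boolean =
    blockIds.forall { id =>
      id.startsWith(ShuffleBlockPrefix) || id.startsWith(ShuffleChunkPrefix)
    }

  // With the trailing underscore the two prefixes are disjoint, so an ID can be
  // classified as a chunk or a block without ambiguity.
  def isChunk(blockId: String): Boolean = blockId.startsWith(ShuffleChunkPrefix)
  def isBlock(blockId: String): Boolean = blockId.startsWith(ShuffleBlockPrefix)
}
```

Without the underscore, a bare `shuffle` prefix would also match chunk IDs such as `shuffleChunk_0_1_2`, so the two kinds of ID could not be told apart by prefix alone.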






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649603473



##########
File path: core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -22,31 +22,40 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
 import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 

Review comment:
       I have added a test for (b): `failed to fetch merged block as well as fallback block should throw a FetchFailedException`

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure with a shuffle merged block/chunk.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged")
+    // If there is any fall back block that's a local block, we get them here. The original
+    // invocation to fetchLocalBlocks might have already returned by this time, so we need
+    // to invoke it again here.
+    fetchLocalBlocks(fallbackLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(fallbackMergedLocalBlocks.isEmpty,
+      "There should be zero merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host as the block chunk that had
+   * a fetch failure.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleBlockChunk(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll(req => {
+        val firstBlock = req.blocks.head
+        firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleBlockChunk(firstBlock.blockId)
+      })
+      fetchRequestsToRemove.foreach(req => {
+        removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])
+      })
+    }
+
+    filterRequests(fetchRequests)
+    val defRequests = deferredFetchRequests.remove(address).orNull
+    if (defRequests != null) {
+      filterRequests(defRequests)
+      if (defRequests.nonEmpty) {
+        deferredFetchRequests(address) = defRequests
+      }
+    }

Review comment:
       I made this change along with the rest of the changes, so I am resolving this.
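
For readers following the hunk above, the queue filtering in `removePendingChunks` can be sketched in isolation. This is a simplified illustration, not the PR's code: `ChunkId` and `Request` are hypothetical stand-ins for `ShuffleBlockChunkId` and `FetchRequest`, the address is a plain `String`, and every request is assumed to be non-empty and to hold only chunks (the real method guards the cast with `isShuffleChunk`).

```scala
import scala.collection.mutable

object PendingChunkFilter {
  // Hypothetical stand-ins for ShuffleBlockChunkId and FetchRequest.
  final case class ChunkId(shuffleId: Int, reduceId: Int, chunkId: Int)
  final case class Request(address: String, chunks: Seq[ChunkId])

  /**
   * Removes every queued request that targets `address` and whose first chunk
   * belongs to the same (shuffleId, reduceId) as the failed chunk.
   * Returns the IDs of all chunks carried by the removed requests.
   */
  def removePending(
      queue: mutable.Queue[Request],
      failed: ChunkId,
      address: String): Set[ChunkId] = {
    // dequeueAll removes the matching elements from the queue and returns them.
    val removed = queue.dequeueAll { req =>
      val first = req.chunks.head
      req.address == address &&
        first.shuffleId == failed.shuffleId &&
        first.reduceId == failed.reduceId
    }
    removed.flatMap(_.chunks).toSet
  }
}
```

Requests for other hosts, or for other reduce partitions on the same host, stay in the queue untouched.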

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure with a shuffle merged block/chunk.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged")
+    // If there is any fall back block that's a local block, we get them here. The original
+    // invocation to fetchLocalBlocks might have already returned by this time, so we need
+    // to invoke it again here.

Review comment:
       This comment didn't make much sense so I have just removed it. 

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch

Review comment:
       done. 

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)

Review comment:
       I renamed this to `isMergedBlockAddressRemote` and also modified the comment.
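
As a rough illustration of the address checks involved, the sketch below puts the predicate quoted in the hunk next to a narrower one that matches the new name. The narrower body is only one possible reading of the rename, not the confirmed revision. `Address` is a hypothetical stand-in for `BlockManagerId`, and the value of `MergerId` is a placeholder for `SHUFFLE_MERGER_IDENTIFIER`.

```scala
object MergedAddressChecks {
  // Hypothetical stand-in for BlockManagerId with only the fields used here.
  final case class Address(executorId: String, host: String, port: Int)

  // Placeholder for SHUFFLE_MERGER_IDENTIFIER; the exact string is an assumption.
  val MergerId: String = "shuffle-push-merger"

  def isMerged(a: Address): Boolean = a.executorId == MergerId

  // Merged block stored on the same host as this executor.
  def isMergedLocal(a: Address, self: Address): Boolean =
    isMerged(a) && a.host == self.host

  // Assumed narrower predicate: a merged block stored on a different host.
  def isMergedRemote(a: Address, self: Address): Boolean =
    isMerged(a) && a.host != self.host

  // The predicate as quoted in the hunk: true for remote merged addresses and
  // for any non-merged address other than this executor itself.
  def isNotExecutorOrMergedLocal(a: Address, self: Address): Boolean =
    isMergedRemote(a, self) || (!isMerged(a) && a != self)
}
```

The two predicates differ only for non-merged addresses: a regular remote executor satisfies `isNotExecutorOrMergedLocal` but not `isMergedRemote`.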

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)

Review comment:
       Added the assertion.
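
The bitmap merge that this assertion guards can be sketched on its own. This is a simplified illustration under stated assumptions: `java.util.BitSet` stands in for `RoaringBitmap` so the sketch needs no external dependency, chunk IDs are plain `Int`s, and `ChunkBitmapMerge` and its method names are hypothetical.

```scala
import java.util.BitSet
import scala.collection.mutable

object ChunkBitmapMerge {
  // Builds a BitSet with the given map indices set.
  def bitSetOf(ids: Int*): BitSet = {
    val bits = new BitSet()
    ids.foreach(i => bits.set(i))
    bits
  }

  /**
   * Removes the bitmap of the failed chunk and of every pending chunk from
   * `meta`, and returns their union. Each pending chunk is expected to have
   * metadata; a missing entry trips the assertion instead of causing a
   * NullPointerException inside `or`.
   */
  def mergeForFallback(
      meta: mutable.HashMap[Int, BitSet],
      failedChunk: Int,
      pendingChunks: Seq[Int]): BitSet = {
    val merged = meta.remove(failedChunk).orNull
    assert(merged != null, s"no bitmap for failed chunk $failedChunk")
    pendingChunks.foreach { chunk =>
      val bitmap = meta.remove(chunk).orNull
      assert(bitmap != null, s"no bitmap for pending chunk $chunk")
      merged.or(bitmap) // in-place union, as with RoaringBitmap.or
    }
    merged
  }
}
```

The returned bitmap covers every map output contained in the failed chunk and the pending chunks, which is the set a fallback would need to re-fetch as original blocks.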

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is neither executor-local nor merged-local, false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId: ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return true if the fetch succeeded, false if it failed
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {

Review comment:
       Done.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +360,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $mergedLocalBlockBytes")
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {

Review comment:
       Done.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       @mridulm That's a very good catch. Sorry, I missed that. A Netty thread will be adding to it for the local block. I will work on fixing this.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       I have changed this to a `ConcurrentHashMap` and switched the order of these lines in `fetchMergedLocalBlock`
   ```
           chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
           iterator.addToResultsQueue(
             SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
               isNetworkReqDone = false))
   ```
   I think this should resolve the issue: the only guarantee we need is that, by the time a `SuccessFetchResult` for a shuffleChunkId is consumed, the meta for that shuffleChunkId is already present in `chunksMetaMap`.
   I will think a bit more about this tomorrow.
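
   To make the ordering argument concrete, here is a minimal sketch of the pattern, not the actual patch. `ChunkId`, `ChunkMeta`, `ChunkResult` and `PublishOrderSketch` are hypothetical stand-ins for `ShuffleBlockChunkId`, the chunk bitmap, `SuccessFetchResult` and the helper, and a `LinkedBlockingQueue` is assumed for the results queue. The producer stores the meta before enqueuing the result, so a consumer that takes the result from the queue should always find the meta.

   ```scala
   import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}

   // Hypothetical stand-ins; none of these names come from the PR.
   final case class ChunkId(shuffleId: Int, reduceId: Int, chunkId: Int)
   final case class ChunkMeta(numBlocks: Int)
   final case class ChunkResult(id: ChunkId)

   object PublishOrderSketch {
     // Plays the role of chunksMetaMap: written by a producer thread, read by the consumer.
     private val chunksMeta = new ConcurrentHashMap[ChunkId, ChunkMeta]()
     // Plays the role of the iterator's results queue (assumed to be a blocking queue).
     private val results = new LinkedBlockingQueue[ChunkResult]()

     // Producer side: store the meta first, then make the result visible.
     def publish(id: ChunkId, meta: ChunkMeta): Unit = {
       chunksMeta.put(id, meta)
       results.put(ChunkResult(id))
     }

     // Consumer side: a result taken from the queue must already have its meta.
     def consume(): ChunkMeta = {
       val result = results.take()
       val meta = chunksMeta.remove(result.id)
       assert(meta != null, s"meta missing for ${result.id}")
       meta
     }
   }
   ```

   With the original order (enqueue, then `put`), the consumer could take the result and look up the meta before the producer had stored it, which is the window the swap is meant to close.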

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is neither executor-local nor merged-local, false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId: ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return true if the fetch succeeded, false if it failed
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)

Review comment:
       I have changed this to a `ConcurrentHashMap` and am now using `get`. I have also kept the assert check.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       I have changed this to a `ConcurrentHashMap` and switched the order of these lines in `fetchMergedLocalBlock`:
   ```
           chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
           iterator.addToResultsQueue(
             SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
               isNetworkReqDone = false))
   ```
   I think this should solve any issues because all we want is the guarantee that while processing a `SuccessFetchResult` related to a shuffleChunkId, the meta for that shuffleChunkId is present in `chunksMetaMap`. 
   I will think a bit more about this tomorrow. 
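   
   To illustrate the ordering guarantee described above, here is a minimal sketch in Java. The class `ChunkMetaOrdering`, its fields, and the chunk id string are hypothetical stand-ins, not code from this PR. The only point is that the meta is put into the map before the result is enqueued, so a consumer that sees the result can always find the meta.
   
   ```java
   import java.util.Queue;
   import java.util.concurrent.ConcurrentHashMap;
   import java.util.concurrent.ConcurrentLinkedQueue;

   // Hypothetical stand-in for the helper; all names are illustrative only.
   class ChunkMetaOrdering {
     // Stands in for chunksMetaMap: chunk id -> number of blocks in the chunk.
     static final ConcurrentHashMap<String, Integer> metaMap = new ConcurrentHashMap<>();
     // Stands in for the iterator's results queue.
     static final Queue<String> results = new ConcurrentLinkedQueue<>();

     // Producer side: publish the meta first, then enqueue the result.
     static void publish(String chunkId, int numBlocks) {
       metaMap.put(chunkId, numBlocks);
       results.add(chunkId);
     }

     // Consumer side: once a result is visible, its meta must already be present.
     static int consume() {
       String chunkId = results.poll();
       if (chunkId == null) {
         throw new IllegalStateException("no result available");
       }
       Integer numBlocks = metaMap.get(chunkId);
       if (numBlocks == null) {
         throw new IllegalStateException("meta missing for " + chunkId);
       }
       return numBlocks;
     }
   }
   ```
   
   With the two statements in `publish` swapped, a consumer on another thread could poll the result before the `put` happens and hit the "meta missing" branch.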

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1394,54 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an un-successful fetch of either of these:
+   * 1) Remote shuffle block chunk.
+   * 2) Local merged block data.
+   *
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,

Review comment:
       Made this change.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>

Review comment:
       Made this change, so resolving it.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       A lot of methods in `PushBasedFetchHelper` also need access to the iterator instance. The helper needs to work with the iterator to be able to:
   1. add results to the iterator's `result` queue when it receives the meta response.
   2. update the number of blocks to fetch.
   3. fetch fallback blocks when there is a fallback, which in turn removes some pending blocks from `fetchRequests`.
   
   It also needs access to the `shuffleClient`, `blockManager`, and `mapOutputTracker`. Most of the methods in this class access one or more of these instances.
   
   Also, each instance of the helper contains `chunksMetaMap`. To make `PushBasedFetchHelper` a trait, this map would have to be moved into `ShuffleBlockFetcherIterator` and then passed to each method in the helper that needs it.
   
   IMO, it seems better to create an instance of `PushBasedFetchHelper` per iterator instance. Otherwise, all the methods of `PushBasedFetchHelper` will have way more arguments.
   
   I find this class similar to the existing `BufferReleasingInputStream` in the iterator.
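   
   The trade-off can be sketched in Java as follows. Everything here is hypothetical (`FetchHelperSketch`, its fields, and both methods are made-up names, not the PR's API). It only contrasts a helper instance that holds its collaborators with a stateless variant that needs them passed on every call.
   
   ```java
   import java.util.HashMap;
   import java.util.Map;
   import java.util.Queue;

   // Hypothetical names throughout; this is not the PR's code.
   class FetchHelperSketch {
     // Collaborator owned by the iterator and shared with the helper.
     private final Queue<String> results;
     // State owned by the helper itself (stands in for chunksMetaMap).
     private final Map<String, Integer> chunkMeta = new HashMap<>();

     FetchHelperSketch(Queue<String> results) {
       this.results = results;
     }

     // Per-instance style: callers pass only what varies per call.
     void onMetaReceived(String chunkId, int numBlocks) {
       chunkMeta.put(chunkId, numBlocks);
       results.add(chunkId);
     }

     int blocksInChunk(String chunkId) {
       return chunkMeta.get(chunkId);
     }

     // Stateless style: every collaborator becomes an extra parameter.
     static void onMetaReceivedStateless(
         Queue<String> results, Map<String, Integer> chunkMeta, String chunkId, int numBlocks) {
       chunkMeta.put(chunkId, numBlocks);
       results.add(chunkId);
     }
   }
   ```
   
   In the real helper there would be more collaborators than in this sketch, so the parameter lists of the stateless variant would grow accordingly.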




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645674391



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       It's not going to be null. The implementation either throws an exception or returns a `FileSegmentManagedBuffer`. However, the other iterators, `ManagedBufferIterator` and `ShuffleManagedBufferIterator`, also check whether the block is null even though `blockManager.getBlockData` and `blockManager.getContinuousBlocksData` will not return null.
   
   Please let me know if I should remove this check for `ShuffleChunkManagedBufferIterator`.






[GitHub] [spark] otterc edited a comment on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-856524008


   Removed all the changes from here which are now part of SPARK-35671 and are not needed for this WIP PR to compile.




[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649690353



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],

Review comment:
       Looks like my comments overlapped with @Ngone51's comments a lot :-)

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>

Review comment:
       The problem with catching `Throwable` is that it might suppress cases that should not be handled, such as an `Error` like OOM (of course, an OOM will typically result in other failures as well, but I am simply illustrating the issue).
   Catching `Exception` should suffice here.
   There are indeed some cases that have slipped through (catching `Throwable`), but those are not common.
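   
   A minimal Java sketch of the difference, under the assumption that a parse failure should turn into a fallback while JVM errors should not be swallowed. `CatchScope` and `parseOrFallback` are hypothetical names used only for illustration.
   
   ```java
   import java.util.function.Supplier;

   // Hypothetical illustration; not code from this PR.
   class CatchScope {
     // Turns ordinary exceptions into a fallback marker, but lets Errors
     // (for example OutOfMemoryError) propagate to the caller.
     static String parseOrFallback(Supplier<String> parse) {
       try {
         return parse.get();
       } catch (Exception e) {
         // Recoverable failure: report it so the caller can fall back.
         return "fallback";
       }
     }
   }
   ```
   
   Had the handler been `catch (Throwable t)`, an `OutOfMemoryError` thrown inside `parse` would also be converted into `"fallback"`. In Scala code, matching on `scala.util.control.NonFatal` is a common way to express a similar intent.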

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)

Review comment:
       If we are sure it can't be `null`, replace `orNull` with `get` instead? It makes the semantics clearer.
   I am fine with leaving it with the `assert` check as well.
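
   As a minimal sketch of the difference, assuming a plain `mutable.HashMap[String, String]` as a hypothetical stand-in for `chunksMetaMap` (the object name, keys, and values below are illustrative, not from the PR): `remove` returns an `Option`, so `orNull` turns a missing key into a `null` that has to be checked later, while `get` fails right at the lookup.

```scala
import scala.collection.mutable
import scala.util.Try

object OrNullVsGet {
  // Hypothetical stand-in for chunksMetaMap; keys and values are illustrative only.
  val meta = mutable.HashMap("chunk-0" -> "bitmap-0")

  def main(args: Array[String]): Unit = {
    // orNull: a missing key silently becomes null and must be checked separately.
    val viaOrNull: String = meta.remove("missing").orNull
    assert(viaOrNull == null)

    // get: a missing key fails at the lookup itself with NoSuchElementException.
    val viaGet = Try(meta.remove("missing").get)
    assert(viaGet.isFailure)

    // Present key: both forms hand back the stored value.
    assert(meta.remove("chunk-0").get == "bitmap-0")
  }
}
```

   With `get`, the failure points at the map lookup rather than at a later dereference, which is the clearer semantics referred to above.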

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       `fetchMergedLocalBlock` is not invoked from the task thread, right (it is called from `fetchMergedLocalBlocks`)?
   And that updates `chunksMetaMap`?
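
   One possible way to avoid the concern, sketched here under assumptions rather than taken from the PR: callback threads only enqueue results, and a single consuming thread is the only one that mutates the non-thread-safe map. `ConfinedChunkMap`, `ChunkMeta`, and the integer values are hypothetical names and data used for illustration.

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable

object ConfinedChunkMap {
  // Hypothetical result type; not a class from the PR.
  final case class ChunkMeta(chunkId: Int, cardinality: Int)

  def collect(numChunks: Int): Map[Int, Int] = {
    val results = new LinkedBlockingQueue[ChunkMeta]()
    // Only the consuming thread ever touches this map.
    val chunksMeta = mutable.HashMap.empty[Int, Int]

    // Producer thread (standing in for a callback thread) only enqueues results.
    val producer = new Thread(() => {
      (0 until numChunks).foreach(i => results.put(ChunkMeta(i, i * 10)))
    })
    producer.start()

    // Consumer (standing in for the task thread) drains the queue and updates the map.
    (0 until numChunks).foreach { _ =>
      val r = results.take()
      chunksMeta.put(r.chunkId, r.cardinality)
    }
    producer.join()
    chunksMeta.toMap
  }
}
```

   The alternative would be a concurrent map, but confining the updates to one thread keeps the plain `mutable.HashMap` safe without extra synchronization.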

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1394,54 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an un-successful fetch of either of these:
+   * 1) Remote shuffle block chunk.
+   * 2) Local merged block data.
+   *
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,

Review comment:
       Put `Merge` in the class name, just in case we have other forms of fallback in the future that need specialized handling.

##########
File path: core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -22,31 +22,40 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
 import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 

Review comment:
       To clarify, what I meant was that when both fail, the driver should see it as a fetch failure (it won't see the merge part, but it will see a fetch failure).
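
   A rough sketch of that behavior, with hypothetical names (`FallbackThenFail`, `FetchFailed`, `fetch`) that are not part of Spark: the merged fetch is tried first, its failure is swallowed in favor of the original blocks, and only a failure of the fallback surfaces as a fetch failure.

```scala
import scala.util.{Failure, Success, Try}

object FallbackThenFail {
  // Hypothetical exception standing in for a fetch failure reported to the driver.
  final class FetchFailed(msg: String, cause: Throwable) extends Exception(msg, cause)

  // Try the merged fetch first; on failure, quietly fall back to the original blocks.
  // Only if the fallback also fails does a fetch failure surface to the caller.
  def fetch[A](merged: () => A, original: () => A): A =
    Try(merged()) match {
      case Success(data) => data
      case Failure(_) =>
        Try(original()) match {
          case Success(data) => data
          case Failure(e) => throw new FetchFailed("original block fetch failed", e)
        }
    }
}
```

   The merge failure never reaches the caller on its own; the caller sees either data or a single fetch failure whose cause is the failed fallback.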




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870315289


   The test failures are unrelated.




[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648857212



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       Oh, I see. I misread it.






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660704248



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +878,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the local push-merged meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the local push-merged directories from the ESS. In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) =>
+            // Fetch local push-merged shuffle block data as multiple shuffle chunks
+            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            try {
+              val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+                localDirs)
+              // Since the request for local block meta completed successfully, numBlocksToFetch
+              // is decremented.
+              numBlocksToFetch -= 1
+              // Update total number of blocks to fetch, reflecting the multiple local shuffle
+              // chunks.
+              numBlocksToFetch += bufs.size
+              bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+                  isNetworkReqDone = false))
+              }
+            } catch {
+              case e: Exception =>
+                // If we see an exception with reading local push-merged data, we fallback to

Review comment:
       An `IOException` could be thrown while reading either the index file or the data file, and it would be caught here. I can explicitly mention `push-merged data/index file` in the comment.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] AmplabJenkins removed a comment on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
AmplabJenkins removed a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-818410050


   Can one of the admins verify this patch?




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648705513



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +487,51 @@ final class ShuffleBlockFetcherIterator(
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
     var curBlocks = Seq.empty[FetchBlockInfo]
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
       assertPositiveBlockSize(blockId, size)
       curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks, address, isLast = false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+              curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+            curRequestSize = curBlocks.map(_.size).sum

Review comment:
       Offline review comment from @mridulm
   > the fetch request is not fetching the actual data, but just the list of chunks to be fetched for that merged block, right ? Given the size of that response is so low, why are we considering size of the block for updating curRequestSize there then 
   
   This is a good point, so I will remove this check. I also found some other discrepancies.
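
One way to read the suggestion above is that meta requests should be bounded only by the number of blocks per request, not by block size, since a meta response carries chunk metadata rather than the block data. The following is a minimal sketch under that assumption. The class `MetaRequestGroupingSketch`, the method `groupMetaRequests`, and the block ID strings are hypothetical and are not part of this PR.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch, not the PR's code: groups meta requests by count only.
final class MetaRequestGroupingSketch {

  /**
   * Splits merged block IDs into requests of at most maxBlocksPerRequest entries.
   * Block sizes are deliberately ignored, because a meta response carries only
   * chunk metadata and not the block data itself.
   */
  static List<List<String>> groupMetaRequests(List<String> blockIds, int maxBlocksPerRequest) {
    if (maxBlocksPerRequest <= 0) {
      throw new IllegalArgumentException("maxBlocksPerRequest must be positive");
    }
    List<List<String>> requests = new ArrayList<>();
    List<String> current = new ArrayList<>();
    for (String blockId : blockIds) {
      current.add(blockId);
      // Close the current request as soon as it reaches the count limit.
      if (current.size() >= maxBlocksPerRequest) {
        requests.add(current);
        current = new ArrayList<>();
      }
    }
    // Whatever is left over forms the last (possibly smaller) request.
    if (!current.isEmpty()) {
      requests.add(current);
    }
    return requests;
  }

  public static void main(String[] args) {
    // Illustrative IDs only; real merged block IDs have a different format.
    List<String> ids = Arrays.asList("merged_a", "merged_b", "merged_c");
    System.out.println(groupMetaRequests(ids, 2)); // [[merged_a, merged_b], [merged_c]]
  }
}
```

With this shape, a request holding many large merged blocks is no different from one holding small ones, which matches the observation that only the chunk list comes back.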






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645900987



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }

Review comment:
       done






[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648820571



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       > * fetchUpToMaxBytes() is always called after processing the response.
   
   I doubt this point, but I'll check the test first, since I could be missing something.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649472914



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)

Review comment:
       `chunkBitmap` should not be `null`. I think the code is currently missing that assertion, which I will add.
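
As a rough illustration of the assertion being discussed, the sketch below merges the bitmap of a failed chunk with the bitmaps of the pending chunks and fails fast if any of them is missing. It is a simplified stand-in, not the PR's code: `java.util.BitSet` replaces `RoaringBitmap` to keep the example self-contained, and the class `ChunkFallbackSketch`, its methods, and the chunk ID strings are hypothetical.

```java
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical sketch: BitSet stands in for RoaringBitmap.
final class ChunkFallbackSketch {
  // Tracks, per chunk, which map indexes were merged into that chunk.
  private final Map<String, BitSet> chunksMeta = new HashMap<>();

  void addChunk(String chunkId, BitSet mapIndexes) {
    chunksMeta.put(chunkId, mapIndexes);
  }

  /**
   * Removes the failed chunk and all pending chunks from the tracking map and
   * returns the union of their bitmaps. A missing bitmap is treated as a bug
   * and reported immediately instead of surfacing later as a null dereference.
   */
  BitSet collectFallbackMapIndexes(String failedChunkId, List<String> pendingChunkIds) {
    BitSet merged = Objects.requireNonNull(
        chunksMeta.remove(failedChunkId), "no bitmap tracked for " + failedChunkId);
    for (String pendingChunkId : pendingChunkIds) {
      BitSet pending = Objects.requireNonNull(
          chunksMeta.remove(pendingChunkId), "no bitmap tracked for " + pendingChunkId);
      merged.or(pending);
    }
    return merged;
  }

  public static void main(String[] args) {
    ChunkFallbackSketch sketch = new ChunkFallbackSketch();
    BitSet first = new BitSet();
    first.set(0);
    first.set(2);
    BitSet second = new BitSet();
    second.set(1);
    sketch.addChunk("chunk_0", first);
    sketch.addChunk("chunk_1", second);
    System.out.println(sketch.collectFallbackMapIndexes("chunk_0", Arrays.asList("chunk_1"))); // {0, 1, 2}
  }
}
```

Checking the failed chunk's bitmap before the loop mirrors the existing check on the pending chunks, so both lookups fail the same way when a chunk is not tracked.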






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645679706



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {

Review comment:
       Good point. I missed the `_`. I will add `_` to the prefixes so that nothing changes with respect to the existing feature. Once `_` is added, the check for `SHUFFLE_CHUNK_PREFIX` will no longer be redundant.
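
To illustrate why the trailing `_` matters, here is a minimal sketch. It assumes the prefixes end up as `shuffle_` and `shuffleChunk_`; those values are an assumption, since this thread only establishes that the `_` was missing. The class name `PrefixCheckSketch` and the block ID strings are made up for the example.

```java
// Hypothetical sketch of the prefix check; the prefix values are assumed.
final class PrefixCheckSketch {
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  /** Returns true only if every ID starts with one of the two shuffle prefixes. */
  static boolean areShuffleBlocksOrChunks(String[] blockIds) {
    for (String blockId : blockIds) {
      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }

  public static void main(String[] args) {
    String chunkId = "shuffleChunk_0_1_2";
    // Without the underscore, the block prefix also matches chunk IDs,
    // so a separate chunk-prefix check would add nothing.
    System.out.println(chunkId.startsWith("shuffle"));            // true
    // With the underscore, the two prefixes are disjoint.
    System.out.println(chunkId.startsWith(SHUFFLE_BLOCK_PREFIX)); // false
    System.out.println(chunkId.startsWith(SHUFFLE_CHUNK_PREFIX)); // true
  }
}
```

With disjoint prefixes, the first element of the array is enough to decide which message type to build, provided all IDs in one request are of the same kind.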






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-833027629


   I will resolve the conflicts once this [PR](https://github.com/apache/spark/pull/32389) is merged.




[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r647851101



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],

Review comment:
       nit: Use `mutable.LinkedHash*` everywhere, or import the classes and use them directly?
   This PR currently mixes several forms.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap

Review comment:
       nit: Move `}.toMap` to next line

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          val pushMergedBlockInfos = blockInfos.map(
+            info => FetchBlockInfo(info._1, info._2, info._3))
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          val size = pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $size")
+          mergedLocalBlockBytes += size
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+      s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    if (hostLocalBlocksCurrentIteration.nonEmpty) {

Review comment:
       super nit: remove the `nonEmpty` check.
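
   A minimal sketch of why the guard is unnecessary, using stand-in types rather than the PR's classes (`NonEmptyCheckSketch`, `appendAll`, and the `(String, Int)` tuples are hypothetical): `++=` with an empty collection leaves the target unchanged.

```scala
import scala.collection.mutable

object NonEmptyCheckSketch {
  // Stand-in for the iterator's hostLocalBlocks field;
  // (String, Int) replaces (BlockId, Int) for illustration only.
  def appendAll(
      target: mutable.LinkedHashSet[(String, Int)],
      current: mutable.LinkedHashSet[(String, Int)]): mutable.LinkedHashSet[(String, Int)] = {
    // No nonEmpty guard: appending an empty set is a no-op.
    target ++= current
    target
  }
}
```

   The same reasoning should apply to calling `++=` on `hostLocalBlocks` unconditionally.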

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          val pushMergedBlockInfos = blockInfos.map(
+            info => FetchBlockInfo(info._1, info._2, info._3))
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          val size = pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $size")
+          mergedLocalBlockBytes += size
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+      s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    if (hostLocalBlocksCurrentIteration.nonEmpty) {
+      this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
+    }
     collectedRemoteRequests
   }
 
   private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
-      address: BlockManagerId): FetchRequest = {
+      address: BlockManagerId,
+      forMergedMetas: Boolean = false): FetchRequest = {

Review comment:
       Remove the default value for `forMergedMetas` ?
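
   One possible motivation, sketched with hypothetical stand-in types (`DefaultArgSketch` and `Request` are not part of the PR): without a default, every call site has to state whether the request is for merged metas, so the compiler flags any caller that forgets.

```scala
object DefaultArgSketch {
  // Stand-in for FetchRequest; String replaces FetchBlockInfo and BlockManagerId.
  final case class Request(blocks: Seq[String], address: String, forMergedMetas: Boolean)

  // No default for forMergedMetas: each caller must choose explicitly,
  // so a merged-meta request cannot become a regular one by omission.
  def createFetchRequest(
      blocks: Seq[String],
      address: String,
      forMergedMetas: Boolean): Request = {
    Request(blocks, address, forMergedMetas)
  }
}
```

   A call then reads `createFetchRequest(blocks, address, forMergedMetas = false)`, which keeps the intent visible at the call site.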

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>

Review comment:
       Why catch `Throwable` ?
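
   For reference, a common Scala alternative is `scala.util.control.NonFatal`, which lets fatal errors propagate. This is only a sketch of the idiom; `NonFatalSketch` and `runWithFallback` are hypothetical names, and whether it fits this callback is for the author to judge.

```scala
import scala.util.control.NonFatal

object NonFatalSketch {
  // Runs `body`; recoverable exceptions go to `onFailure`, while fatal ones
  // (e.g. OutOfMemoryError, InterruptedException) are not caught here.
  def runWithFallback(body: => Unit)(onFailure: Throwable => Unit): Unit = {
    try {
      body
    } catch {
      case NonFatal(e) => onFailure(e)
    }
  }
}
```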

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()

Review comment:
       This map can be modified concurrently; should we handle thread safety?
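
   If the map is indeed touched from more than one thread, one option is a `java.util.concurrent.ConcurrentHashMap`. A minimal sketch with stand-in types (`ChunkMetaMapSketch` is hypothetical; `String` replaces `ShuffleBlockChunkId` and `Integer` replaces `RoaringBitmap`):

```scala
import java.util.concurrent.ConcurrentHashMap

object ChunkMetaMapSketch {
  // ConcurrentHashMap allows safe concurrent put/remove without external locking.
  private val chunksMeta = new ConcurrentHashMap[String, Integer]()

  def put(chunkId: String, numBlocks: Int): Unit = {
    chunksMeta.put(chunkId, Integer.valueOf(numBlocks))
  }

  def remove(chunkId: String): Unit = {
    chunksMeta.remove(chunkId)
  }

  def size: Int = chunksMeta.size()
}
```

   Confining all access to a single thread would be another way to address this; the sketch only shows the data-structure route.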

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)

Review comment:
       For merged blocks, why are we calling `mergeContinuousShuffleBlockIdsIfNeeded`?
   Currently, this is a no-op anyway.
   
   We can remove `pushMergedBlockInfos` entirely here.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)

Review comment:
       nit: use `==`
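
   In Scala, `==` on references is null-safe and delegates to `equals`, so it is the idiomatic form. A small sketch (`EqualsSketch`, `MergerId`, and its value are placeholders, not the PR's constant):

```scala
object EqualsSketch {
  // Placeholder value for illustration; not taken from the PR.
  val MergerId: String = "merger-id-placeholder"

  def isMerger(executorId: String): Boolean = {
    // == calls equals after a null check, so a null executorId yields false
    // rather than throwing.
    MergerId == executorId
  }
}
```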

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")

Review comment:
       nit: should this be `logInfo`?

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)

Review comment:
       Does this (and its caller tree) support SPARK-27651?
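
   To make the question concrete, here is a rough sketch of how the predicate classifies addresses, assuming SPARK-27651 refers to host-local shuffle reads (blocks owned by another executor on the same host). `Addr`, `MergerId` and the ports are hypothetical stand-ins; only the boolean structure is copied from the patch:

```scala
object AddressClassification {
  // Hypothetical, simplified stand-in for BlockManagerId.
  final case class Addr(executorId: String, host: String, port: Int)

  // Stand-in value for SHUFFLE_MERGER_IDENTIFIER.
  val MergerId = "shuffle-push-merger"

  def isMerged(a: Addr): Boolean = a.executorId == MergerId

  // Same boolean structure as isNotExecutorOrMergedLocal in the patch.
  def isNotExecutorOrMergedLocal(self: Addr, a: Addr): Boolean =
    (isMerged(a) && a.host != self.host) || (!isMerged(a) && a != self)

  def main(args: Array[String]): Unit = {
    val self = Addr("1", "hostA", 7337)
    assert(!isNotExecutorOrMergedLocal(self, self))                          // executor-local
    assert(!isNotExecutorOrMergedLocal(self, Addr(MergerId, "hostA", 7337))) // merged-local
    assert(isNotExecutorOrMergedLocal(self, Addr(MergerId, "hostB", 7337)))  // merged-remote
    // Another executor on the same host is treated like a remote address.
    assert(isNotExecutorOrMergedLocal(self, Addr("2", "hostA", 7338)))
  }
}
```

   Under these assumptions a host-local address falls into the same bucket as a remote one, which is the case worth double-checking in the callers.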

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)

Review comment:
       `checkBlockSizes` is now called for all cases, whereas earlier it was not called for the last `else` branch.
   Did you look into whether this is ok?
   
   +CC @attilapiros who did this change initially.

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)

Review comment:
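       Can we hit an NPE here? `chunkBitmap` comes from `chunksMetaMap.remove(shuffleChunkId).orNull` and is not null-checked before `or` is called on it.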
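
   One possible shape for a null-safe version, as a sketch only: it keeps the `Option` returned by `remove` instead of calling `orNull`. The names are hypothetical, with `Int` keys standing in for `ShuffleBlockChunkId` and `mutable.BitSet` standing in for `RoaringBitmap`:

```scala
import scala.collection.mutable

object ChunkBitmapFallback {
  // Hypothetical stand-ins: Int keys instead of ShuffleBlockChunkId,
  // mutable.BitSet instead of RoaringBitmap.
  val chunksMetaMap = mutable.HashMap[Int, mutable.BitSet]()

  /** Merge the bitmaps of the failed chunk and its pending chunks, skipping missing entries. */
  def mergedBitmap(failed: Int, pending: Seq[Int]): Option[mutable.BitSet] = {
    chunksMetaMap.remove(failed).map { chunkBitmap =>
      // Union each pending chunk's bitmap into the failed chunk's bitmap.
      pending.flatMap(id => chunksMetaMap.remove(id)).foreach(chunkBitmap |= _)
      chunkBitmap
    }
  }

  def main(args: Array[String]): Unit = {
    chunksMetaMap.put(0, mutable.BitSet(1, 2))
    chunksMetaMap.put(1, mutable.BitSet(3))
    assert(mergedBitmap(0, Seq(1)).contains(mutable.BitSet(1, 2, 3)))
    // A chunk that is no longer in the map yields None instead of a null dereference.
    assert(mergedBitmap(42, Seq.empty).isEmpty)
  }
}
```

   Note that this sketch silently skips a pending chunk with no bitmap, whereas the patch asserts on it; which behavior is right depends on whether a missing entry can legitimately happen.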

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          val pushMergedBlockInfos = blockInfos.map(
+            info => FetchBlockInfo(info._1, info._2, info._3))
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          val size = pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $size")
+          mergedLocalBlockBytes += size
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +

Review comment:
       Note: this assumes `localBlocks` is empty when the method is invoked.
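
To illustrate the assumption, here is a small sketch with hypothetical names (`DeltaInvariantSketch`, `partition`) and `String` block ids standing in for the real types. The counter is compared as a delta, while the set is compared by its total size, so the check only holds when the set starts out empty.

```scala
import scala.collection.mutable

object DeltaInvariantSketch {
  var numBlocksToFetch = 0

  // Adds `incoming` to `localBlocks` and reports whether the
  // "delta of the counter == size of the set" check would hold.
  def partition(
      incoming: Seq[String],
      localBlocks: mutable.LinkedHashSet[String]): Boolean = {
    val prev = numBlocksToFetch
    numBlocksToFetch += incoming.size
    localBlocks ++= incoming
    (numBlocksToFetch - prev) == localBlocks.size
  }
}
```

With a fresh set the check passes; with a set that already holds an entry from an earlier call it fails, even though nothing is wrong with the new blocks.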

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)

Review comment:
       Is there any possibility of a race here?
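
Whether a race exists depends on which threads touch `chunksMetaMap`, which this hunk does not show. As a sketch under the assumption that more than one thread could remove the same entry, a concurrent map such as `scala.collection.concurrent.TrieMap` makes `remove` atomic, so exactly one caller observes the value. `RaceSketch` and `winners` are hypothetical names.

```scala
import java.util.concurrent.{Executors, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.concurrent.TrieMap

object RaceSketch {
  // Starts `threads` tasks that all try to remove the same key and
  // returns how many of them actually obtained the value.
  def winners(threads: Int): Int = {
    val meta = TrieMap[Int, String](0 -> "bitmap")
    val hits = new AtomicInteger(0)
    val pool = Executors.newFixedThreadPool(threads)
    (1 to threads).foreach { _ =>
      pool.execute(new Runnable {
        override def run(): Unit = {
          if (meta.remove(0).isDefined) {
            hits.incrementAndGet()
          }
        }
      })
    }
    pool.shutdown()
    pool.awaitTermination(10, TimeUnit.SECONDS)
    hits.get()
  }
}
```

A plain `mutable.HashMap` gives no such guarantee under concurrent access, which is what makes the `assert(bitmapOfPendingChunk != null)` worth a second look if the map is shared across threads.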

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          val pushMergedBlockInfos = blockInfos.map(
+            info => FetchBlockInfo(info._1, info._2, info._3))
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          val size = pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $size")
+          mergedLocalBlockBytes += size
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+      s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    if (hostLocalBlocksCurrentIteration.nonEmpty) {
+      this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
+    }
     collectedRemoteRequests
   }
 
   private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
-      address: BlockManagerId): FetchRequest = {
+      address: BlockManagerId,
+      forMergedMetas: Boolean = false): FetchRequest = {
     logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address "
       + s"with ${blocks.size} blocks")
-    FetchRequest(address, blocks)
+    FetchRequest(address, blocks, forMergedMetas)
   }
 
   private def createFetchRequests(
       curBlocks: Seq[FetchBlockInfo],
       address: BlockManagerId,
       isLast: Boolean,
-      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo] = {
-    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch)
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      forMergedMetas: Boolean = false): Seq[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch)

Review comment:
       Is `mergeContinuousShuffleBlockIdsIfNeeded` relevant for merged blocks/chunks?
   If not, are there any side effects of calling it here?
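
For context, the `collectFetchRequests` hunk in this PR passes `enableBatchFetch = false` for both merged blocks and chunks, so the call is made but batching is switched off for them. A minimal sketch of that selection, using hypothetical stand-in types (`ChunkId`, `ShuffleId`) and an assumed sentinel value for `SHUFFLE_PUSH_MAP_ID`:

```scala
object BatchFlagSketch {
  sealed trait Id
  final case class ChunkId(shuffleId: Int, reduceId: Int, chunkId: Int) extends Id
  final case class ShuffleId(shuffleId: Int, mapId: Long, reduceId: Int) extends Id

  // Assumed sentinel marking a push-merged block; stands in for SHUFFLE_PUSH_MAP_ID.
  val PushMapId: Long = -1L

  // Returns (enableBatchFetch, forMergedMetas) based on the first block of a request.
  def flags(head: Id, doBatchFetch: Boolean): (Boolean, Boolean) = head match {
    case ChunkId(_, _, _) => (false, false)           // merged chunk: never batched
    case ShuffleId(_, PushMapId, _) => (false, true)  // merged block: meta request
    case _ => (doBatchFetch, false)                   // original block: honor the config
  }
}
```

This relies on all blocks in one request being of the same kind, as the comment in that hunk states.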

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch

Review comment:
       Rename `foundMoreBlocksToFetch` to `incrementNumBlocksToFetch`?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
     var curBlocks = Seq.empty[FetchBlockInfo]
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
-      assertPositiveBlockSize(blockId, size)
       curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks, address, isLast = false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, areMergedBlocks) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _) => (false, false)
+          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
       curBlocks = createFetchRequests(curBlocks, address, isLast = true,
-        collectedRemoteRequests)
+        collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+        forMergedMetas = areMergedBlocks)
       curRequestSize = curBlocks.map(_.size).sum

Review comment:
       nit: unrelated to this PR, but can we drop this `sum`?

##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Throwable =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {

Review comment:
       This method can only receive a `ShuffleBlockId` or a `ShuffleBlockChunkId`, right?
   Could we add that as a precondition and check `isInstanceOf[ShuffleBlockId]` instead of `isShuffle`?
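   
   A minimal sketch of what such a precondition could look like. The block id classes below are simplified stand-ins for Spark's real ones, and `FallbackPrecondition` and `describeFallback` are hypothetical names used only for illustration:

   ```scala
   // Simplified stand-ins for Spark's block id classes (assumption, not the real types).
   sealed trait BlockId
   final case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId
   final case class ShuffleBlockChunkId(shuffleId: Int, reduceId: Int, chunkId: Int) extends BlockId
   final case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId

   object FallbackPrecondition {
     // Hypothetical helper: rejects anything that is not a merged block or a chunk,
     // then branches on the concrete type instead of a broader `isShuffle` check.
     def describeFallback(blockId: BlockId): String = {
       require(
         blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId],
         s"Expected ShuffleBlockId or ShuffleBlockChunkId, got $blockId")
       blockId match {
         case ShuffleBlockId(shuffleId, _, reduceId) =>
           s"merged block ($shuffleId, $reduceId)"
         case ShuffleBlockChunkId(shuffleId, reduceId, chunkId) =>
           s"chunk $chunkId of merged block ($shuffleId, $reduceId)"
         case other =>
           // Unreachable because of the require above; kept so the match is exhaustive.
           throw new IllegalStateException(s"Unexpected block id $other")
       }
     }
   }
   ```

   With this shape, a caller that passes any other block type fails fast with an `IllegalArgumentException` instead of silently taking the chunk branch.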

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure with a shuffle merged block/chunk.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged")
+    // If there is any fall back block that's a local block, we get them here. The original
+    // invocation to fetchLocalBlocks might have already returned by this time, so we need
+    // to invoke it again here.

Review comment:
       Can we rephrase this comment? The sentence "The original invocation to fetchLocalBlocks might have already returned by this time" makes it sound like a timing issue, and so potentially a race.
   In reality, the initial `fetchLocalBlocks` call was for the initial request, and for each failure to fetch merged blocks/chunks we have to redo the exercise for that set of blocks.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1054,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure with a shuffle merged block/chunk.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged")
+    // If there is any fall back block that's a local block, we get them here. The original
+    // invocation to fetchLocalBlocks might have already returned by this time, so we need
+    // to invoke it again here.
+    fetchLocalBlocks(fallbackLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(fallbackMergedLocalBlocks.isEmpty,
+      "There should be zero merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host as the block chunk that had
+   * a fetch failure.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleBlockChunk(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll(req => {
+        val firstBlock = req.blocks.head
+        firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleBlockChunk(firstBlock.blockId)
+      })
+      fetchRequestsToRemove.foreach(req => {
+        removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])
+      })
+    }
+
+    filterRequests(fetchRequests)
+    val defRequests = deferredFetchRequests.remove(address).orNull
+    if (defRequests != null) {
+      filterRequests(defRequests)
+      if (defRequests.nonEmpty) {
+        deferredFetchRequests(address) = defRequests
+      }
+    }

Review comment:
       nit:
   ```suggestion
       deferredFetchRequests.get(address).foreach(defRequests => {
         filterRequests(defRequests)
         if (defRequests.isEmpty) deferredFetchRequests.remove(address)
       })
   ```
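   
   The suggested form appears to leave `deferredFetchRequests` in the same state as the remove-then-reinsert version in the diff. A small sketch that checks this on toy data, assuming a string address and integer requests in place of `BlockManagerId` and `FetchRequest`; `DeferredCleanupSketch` and all of its members are hypothetical:

   ```scala
   import scala.collection.mutable

   object DeferredCleanupSketch {
     type Deferred = mutable.HashMap[String, mutable.Queue[Int]]

     // Toy data: address "a" keeps some requests after filtering, address "b" keeps none.
     def sample(): Deferred =
       mutable.HashMap("a" -> mutable.Queue(1, 2, 3), "b" -> mutable.Queue(2, 4))

     // Stand-in for filterRequests: removes even numbers from the queue in place.
     private def filterRequests(queue: mutable.Queue[Int]): Unit = {
       queue.dequeueAll(_ % 2 == 0)
       ()
     }

     // Shape used in the diff: remove the entry, filter it, re-insert if non-empty.
     def removeThenReinsert(deferred: Deferred, address: String): Unit = {
       val defRequests = deferred.remove(address).orNull
       if (defRequests != null) {
         filterRequests(defRequests)
         if (defRequests.nonEmpty) {
           deferred(address) = defRequests
         }
       }
     }

     // Shape from the suggestion: filter in place, drop the entry only if it became empty.
     def filterInPlace(deferred: Deferred, address: String): Unit = {
       deferred.get(address).foreach { defRequests =>
         filterRequests(defRequests)
         if (defRequests.isEmpty) deferred.remove(address)
       }
     }
   }
   ```

   Running both variants over `sample()` for addresses "a" and "b" yields equal maps: "a" keeps its odd entries and "b" is removed.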

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +360,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $mergedLocalBlockBytes")
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {

Review comment:
       While we are at it, make it a `Set` ?
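   
   For two elements the difference is mostly about intent, but a `Set` built once outside the loop states the membership test directly. A sketch with hypothetical executor ids (`LocalAddressCheck` and its members are illustrative names, not part of the PR):

   ```scala
   object LocalAddressCheck {
     // Hypothetical ids; in the PR these come from blockManager.blockManagerId.executorId
     // and FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId.
     val localExecutorId = "1"
     val fallbackExecutorId = "fallback"

     // Built once, outside any per-address loop.
     private val localExecutorIds: Set[String] = Set(localExecutorId, fallbackExecutorId)

     def isLocal(executorId: String): Boolean = localExecutorIds.contains(executorId)
   }
   ```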

##########
File path: core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -22,31 +22,40 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
 import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 

Review comment:
       Please also add tests for:
   a) A deserialization failure results in initiating fallback.
   b) A fetch failure of both the merged block and its fallback block gets reported to the driver as a fetch failure.
   
   Are these handled already?

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1394,54 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an un-successful fetch of either of these:
+   * 1) Remote shuffle block chunk.
+   * 2) Local merged block data.
+   *
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,

Review comment:
       We are not ignoring the result as such, but using it to initiate a fallback.
   `IgnoreFetchResult` -> `RetriableMergeFailureResult`? Or something better?
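   
   To illustrate why the name matters: this result is not dropped, it selects the fallback path. A simplified sketch, assuming string block ids and a hypothetical `Action` type that is not part of Spark; only the `RetriableMergeFailureResult` name comes from the suggestion above, and its fields here are reduced to the block id:

   ```scala
   // Simplified stand-ins for the iterator's result types (assumption, not Spark's classes).
   sealed trait FetchResult
   final case class SuccessFetchResult(blockId: String) extends FetchResult
   final case class FailureFetchResult(blockId: String) extends FetchResult
   final case class RetriableMergeFailureResult(blockId: String) extends FetchResult

   object FallbackDispatchSketch {
     // Hypothetical description of what the iterator does next for each result.
     sealed trait Action
     final case class Deliver(blockId: String) extends Action
     final case class ReportFetchFailure(blockId: String) extends Action
     final case class FetchOriginalBlocks(mergedBlockId: String) extends Action

     def next(result: FetchResult): Action = result match {
       case SuccessFetchResult(id) => Deliver(id)
       case FailureFetchResult(id) => ReportFetchFailure(id)
       // A failed merged block or chunk is retried through the original unmerged blocks.
       case RetriableMergeFailureResult(id) => FetchOriginalBlocks(id)
     }
   }
   ```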






[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-868938908


   Jenkins, test this please




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649419059



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,77 +355,118 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)

Review comment:
       @mridulm `checkBlockSizes` was done earlier for the last `else` branch as well, but there it was part of `collectFetchRequests`.
   Based on this conversation
   https://github.com/apache/spark/pull/32140#discussion_r648301802, I moved it before the `if/else if`. I also removed it from `collectFetchRequests` so that it's not done twice for remote blocks.






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656714752



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -661,18 +745,21 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
-              shuffleMetrics.incLocalBlocksFetched(1)
-              shuffleMetrics.incLocalBytesRead(buf.size)
-            } else {
-              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
-              shuffleMetrics.incRemoteBytesRead(buf.size)
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
-                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
-              }
-              shuffleMetrics.incRemoteBlocksFetched(1)
-              bytesInFlight -= size
-            }
+           if (hostLocalBlocks.contains(blockId -> mapIndex) ||
+             pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
+             // It is a host local block or a local shuffle chunk
+             shuffleMetrics.incLocalBlocksFetched(1)
+             shuffleMetrics.incLocalBytesRead(buf.size)
+           } else {
+             // Could be a remote shuffle chunk or remote block
+             numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+             shuffleMetrics.incRemoteBytesRead(buf.size)
+             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+               shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+             }
+             shuffleMetrics.incRemoteBlocksFetched(1)
+             bytesInFlight -= size
+           }

Review comment:
       Actually @mridulm, I didn't even change the indentation here. It is the same as before. I just added lines 749, 750, and 754. I will remove the comment at line 754 and see if this interface stops showing that I have changed this.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645277961



##########
File path: common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
##########
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to find the meta information for the specified merged block. The meta information
+ * contains the number of chunks in the merged blocks and the map ids in each chunk.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
+  public final long requestId;
+  public final String appId;
+  public final int shuffleId;
+  public final int reduceId;
+
+  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+    super(null, false);
+    this.requestId = requestId;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.reduceId = reduceId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaRequest;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(appId) + 8;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(reduceId);
+  }
+
+  public static MergedBlockMetaRequest decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    String appId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int reduceId = buf.readInt();
+    return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof MergedBlockMetaRequest) {
+      MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
+      return requestId == o.requestId && Objects.equal(appId, o.appId)
+        && shuffleId == o.shuffleId && reduceId == o.reduceId;

Review comment:
       nit: move the `appId` check last, so the primitive comparisons run first.
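
   A minimal sketch of the reordering, assuming the intent is to let the cheap primitive comparisons short-circuit before the string comparison. `MetaRequestKey` is a hypothetical stand-in for `MergedBlockMetaRequest`, and it uses `java.util.Objects` rather than Guava only to stay self-contained.

```java
import java.util.Objects;

// Hypothetical stand-in for MergedBlockMetaRequest; only the equality logic matters here.
final class MetaRequestKey {
  final long requestId;
  final String appId;
  final int shuffleId;
  final int reduceId;

  MetaRequestKey(long requestId, String appId, int shuffleId, int reduceId) {
    this.requestId = requestId;
    this.appId = appId;
    this.shuffleId = shuffleId;
    this.reduceId = reduceId;
  }

  @Override
  public boolean equals(Object other) {
    if (!(other instanceof MetaRequestKey)) {
      return false;
    }
    MetaRequestKey o = (MetaRequestKey) other;
    // Primitive fields first; the String comparison runs only if they all match.
    return requestId == o.requestId && shuffleId == o.shuffleId
      && reduceId == o.reduceId && Objects.equals(appId, o.appId);
  }

  @Override
  public int hashCode() {
    return Objects.hash(requestId, appId, shuffleId, reduceId);
  }
}
```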

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {

Review comment:
       nit: I don't have a good suggestion, but any thoughts on a better name for this interface?
   Thoughts @Ngone51, @attilapiros?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {

Review comment:
       The `Iterator` contract requires `next` to throw `NoSuchElementException` when `hasNext` is false.
   Unfortunately, the other iterators in `ExternalBlockHandler` do not do this either.
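
   For illustration, a sketch of an iterator that honours the contract. `ChunkIdIterator` is a hypothetical simplification: it yields `{reduceId, chunkId}` pairs instead of calling `mergeManager.getMergedBlockData`, but it advances its indices the same way as the code under review.

```java
import java.util.Iterator;
import java.util.NoSuchElementException;

// Hypothetical simplification of ShuffleChunkManagedBufferIterator.
final class ChunkIdIterator implements Iterator<int[]> {
  private final int[] reduceIds;
  private final int[][] chunkIds;
  private int reduceIdx = 0;
  private int chunkIdx = 0;

  ChunkIdIterator(int[] reduceIds, int[][] chunkIds) {
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }

  @Override
  public boolean hasNext() {
    return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
  }

  @Override
  public int[] next() {
    // Guard required by the Iterator contract.
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    int[] pair = {reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]};
    if (chunkIdx < chunkIds[reduceIdx].length - 1) {
      chunkIdx += 1;
    } else {
      chunkIdx = 0;
      reduceIdx += 1;
    }
    return pair;
  }
}
```

   Without the guard, an exhausted iterator would fail with `ArrayIndexOutOfBoundsException` instead, which callers written against `Iterator` do not expect.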

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));

Review comment:
       Add a one-line note on what `blockIdParts[3]` can be.

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -124,11 +134,12 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r

Review comment:
       nit: use `\\d+` instead of `[0-9]+`?
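
   A small check of the equivalence, sketched with `java.util.regex` (which Scala's `.r` compiles to). By default, `\d` in Java matches only `[0-9]`, so the two spellings should accept the same block ids; the class and method names below are hypothetical.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Hypothetical helper comparing the two spellings of the SHUFFLE pattern.
final class ShuffleIdPatterns {
  static final Pattern EXPLICIT =
    Pattern.compile("shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)");
  static final Pattern SHORTHAND =
    Pattern.compile("shuffle_(\\d+)_(-?\\d+)_(\\d+)");

  /** Returns the second capture group (the possibly negative map id), or null if no match. */
  static String mapIdOf(Pattern pattern, String blockId) {
    Matcher m = pattern.matcher(blockId);
    return m.matches() ? m.group(2) : null;
  }
}
```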

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -128,24 +134,23 @@ protected void handleMessage(
       BlockTransferMessage msgObj,
       TransportClient client,
       RpcResponseCallback callback) {
-    if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
+    if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
       try {
         int numBlockIds;
         long streamId;
-        if (msgObj instanceof FetchShuffleBlocks) {
-          FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
+        if (msgObj instanceof AbstractFetchShuffleBlocks) {
+          AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
           checkAuth(client, msg.appId);
-          numBlockIds = 0;
-          if (msg.batchFetchEnabled) {
-            numBlockIds = msg.mapIds.length;
+          numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();

Review comment:
       `getNumBlocks` makes this code cleaner.

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
##########
@@ -199,14 +200,31 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.errorString);
       } else {
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof MergedBlockMetaSuccess) {
+      MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
+      MergedBlockMetaResponseCallback listener =
+        (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
+      if (listener == null) {
+        logger.warn(
+          "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
+            + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
+        resp.body().release();
+      } else {
+        outstandingRpcs.remove(resp.requestId);
+        try {
+          listener.onSuccess(resp.getNumChunks(), resp.body());
+        } finally {
+          resp.body().release();

Review comment:
       nit: wrap this entire else block in try/finally so that `resp.body().release()` always runs.
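
   Roughly what I have in mind, as a sketch only. `CountedBuffer` and the `Consumer` listener are hypothetical stand-ins for `ManagedBuffer` and `MergedBlockMetaResponseCallback`; the point is that the lookup, the removal and the callback all sit inside the `try`, so the release happens on every exit path.

```java
import java.util.Map;
import java.util.function.Consumer;

final class ReleaseSketch {
  /** Hypothetical reference-counted buffer; not Spark's ManagedBuffer. */
  static final class CountedBuffer {
    int refCount = 1;
    void release() { refCount--; }
  }

  /** Hypothetical version of the "listener is outstanding" branch. */
  static void onSuccess(
      Map<Long, Consumer<CountedBuffer>> outstanding,
      long requestId,
      CountedBuffer body) {
    try {
      Consumer<CountedBuffer> listener = outstanding.remove(requestId);
      listener.accept(body);
    } finally {
      // Runs whether the removal or the callback throws or not.
      body.release();
    }
  }
}
```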

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -189,9 +194,14 @@ protected void handleMessage(
     } else if (msgObj instanceof GetLocalDirsForExecutors) {
       GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
       checkAuth(client, msg.appId);
-      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
+      String[] execIdsForBlockResolver = Arrays.stream(msg.execIds)
+        .filter(execId -> !SHUFFLE_MERGER_IDENTIFIER.equals(execId)).toArray(String[]::new);
+      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
+        execIdsForBlockResolver);
+      if (Arrays.asList(msg.execIds).contains(SHUFFLE_MERGER_IDENTIFIER)) {
+        localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
+      }

Review comment:
       ```suggestion
         Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
         boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
         Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, execIdsForBlockResolver);
         if (fetchMergedBlockDirs) {
           localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
         }
   ```
   
   With a corresponding change in `blockManager.getLocalDirs` to take a set of executor ids instead of an array.
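
   The suggestion relies on `Set.remove` returning `true` only when the element was present, which folds the filter and the `contains` check into one call. A self-contained sketch, where `MERGER_ID` is a placeholder value and not the real `SHUFFLE_MERGER_IDENTIFIER`:

```java
import java.util.Set;

final class ExecIdSplit {
  // Placeholder value, assumed for illustration only.
  static final String MERGER_ID = "merger-placeholder";

  /**
   * Removes the merger id from the given mutable set and reports whether it was there,
   * leaving only the executor ids the block resolver should see.
   */
  static boolean stripMerger(Set<String> execIds) {
    return execIds.remove(MERGER_ID);
  }
}
```

   One behavioural difference to keep in mind: building a set from `msg.execIds` also drops duplicate executor ids, which the array-based version kept.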

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);

Review comment:
       Just to clarify, we are not modifying the old fetch protocol at all.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }

Review comment:
       nit:
   
   ```suggestion
         BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, id -> new BlocksInfo());
   ```
   
   and remove the `get` below.
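
   As a standalone sketch of the same pattern: `Map.computeIfAbsent` returns the existing or newly created value, so the grouping needs no separate `containsKey`/`put`/`get`. The class below is hypothetical and uses plain lists in place of `BlocksInfo`.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class GroupByPrimaryId {
  /** Groups secondary ids under their primary id, preserving first-seen order of primaries. */
  static Map<Long, List<Integer>> group(long[] primaryIds, int[] secondaryIds) {
    Map<Long, List<Integer>> grouped = new LinkedHashMap<>();
    for (int i = 0; i < primaryIds.length; i++) {
      // Single lookup: creates the list on first sight of a primary id, then appends.
      grouped.computeIfAbsent(primaryIds[i], id -> new ArrayList<>()).add(secondaryIds[i]);
    }
    return grouped;
  }
}
```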

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
+        // For regular blocks this is mapId. For chunks this is reduceId.
         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks this is reduceId. For chunks this is chunkId.
         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
       Do we want to rename this variable (here and in the constructor) and this method, given that they now carry either map/reduce ids or reduce/chunk ids?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);

Review comment:
       Move this assertion into the constructor itself.
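
   A sketch of the check done once at construction time. Note that it swaps the `assert` for an explicit `IllegalArgumentException`, which is my own variation: Java assertions are disabled unless the JVM runs with `-ea`, so an `assert` would usually be a no-op in production. `ChunkRequestShape` is a hypothetical name.

```java
// Hypothetical holder that validates the request shape once, up front.
final class ChunkRequestShape {
  final int[] reduceIds;
  final int[][] chunkIds;

  ChunkRequestShape(int[] reduceIds, int[][] chunkIds) {
    // Same condition as the assertion in hasNext(), evaluated a single time.
    if (reduceIds.length == 0 || reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException(
        "Expected non-empty reduceIds with one chunkIds entry each, got "
          + reduceIds.length + " reduceIds and " + chunkIds.length + " chunkIds entries");
    }
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }
}
```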

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {

Review comment:
       Here we are assuming that the ids are either all chunks or all blocks.
   That is not what `areShuffleBlocksOrChunks` validates: a mix of both can pass.
   
   Do we want to make `areShuffleBlocksOrChunks` stricter?
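
   One possible shape for the stricter check, as a sketch. The prefix values `"shuffle"` and `"shuffleChunk"` are assumptions (the constants are not shown in this hunk), and the `'_'` separator is appended so that a chunk id is not mistaken for a block id.

```java
final class BlockIdKindCheck {
  // Assumed values; the real constants are defined elsewhere in the PR.
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk";

  /** True only if every id is a shuffle block id, or every id is a shuffle chunk id. */
  static boolean allBlocksOrAllChunks(String[] blockIds) {
    if (blockIds.length == 0) {
      return false;
    }
    // The first id decides which kind the rest must match.
    String expected = blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX + "_")
      ? SHUFFLE_CHUNK_PREFIX + "_"
      : SHUFFLE_BLOCK_PREFIX + "_";
    for (String blockId : blockIds) {
      if (!blockId.startsWith(expected)) {
        return false;
      }
    }
    return true;
  }
}
```

   With this, a mixed array falls through to the `OpenBlocks` path instead of being parsed as if every entry had the kind of `blockIds[0]`.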

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       When would `block` be `null`?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       nit: `Longs.toArray` is a bit expensive, and the same goes for `Ints.toArray` below.
   If we can avoid them while keeping the code clean and concise, that would be preferable (a couple of other places in this PR use these APIs).
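
   One possible alternative, sketched under the assumption that the concern is the extra intermediate copy a generic collection-to-array helper makes: fill the primitive array directly while iterating. `toLongArray` is a hypothetical helper, not an existing Spark or Guava method.

```java
import java.util.Collection;

final class PrimitiveArrays {
  /** Copies the ids into a long[] in iteration order, without an intermediate boxed array. */
  static long[] toLongArray(Collection<? extends Number> ids) {
    long[] out = new long[ids.size()];
    int i = 0;
    for (Number id : ids) {
      out[i++] = id.longValue();
    }
    return out;
  }
}
```

   Whether this is measurably cheaper for the sizes seen here would need a benchmark; it is only meant to show that the code can stay short.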

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {

Review comment:
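       super nit: As coded, checking for `SHUFFLE_CHUNK_PREFIX` here is redundant, since any ID that starts with it also starts with `SHUFFLE_BLOCK_PREFIX` - though I am fine with it for clarity.
   Btw, unlike the old `"shuffle_"` check, we are no longer checking for the trailing '_' here.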
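
To illustrate both points, a small sketch. The constant values below are assumptions inferred from this comment (prefixes without a trailing underscore), not copied from the PR, and `PrefixSketch`, `looseMatch`, and `strictMatch` are hypothetical names. With these values every chunk ID already passes the block-prefix test, and so does an unrelated ID that merely begins with the same letters.

```java
final class PrefixSketch {
  // Assumed values; the PR's actual constants may differ.
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk";

  // Prefix test without the separator: also accepts chunk IDs and look-alikes.
  static boolean looseMatch(String blockId) {
    return blockId.startsWith(SHUFFLE_BLOCK_PREFIX);
  }

  // Prefix test that includes the '_' separator.
  static boolean strictMatch(String blockId) {
    return blockId.startsWith(SHUFFLE_BLOCK_PREFIX + "_")
        || blockId.startsWith(SHUFFLE_CHUNK_PREFIX + "_");
  }

  public static void main(String[] args) {
    for (String id : new String[] {"shuffle_0_1_2", "shuffleChunk_0_1_2", "shuffleX_0"}) {
      System.out.println(id + " loose=" + looseMatch(id) + " strict=" + strictMatch(id));
    }
  }
}
```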

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
       We can have `failRemainingBlocks` delegate to `failSingleBlockChunk` now ?
   ```
     private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
       Arrays.stream(failedBlockIds).forEach(blockId -> failSingleBlockChunk(blockId, e));
     }
   ```

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;

Review comment:
       ```suggestion
     return Arrays.stream(blockIds).noneMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX) && !blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
   ```
   
   Note: the IDs that start with `SHUFFLE_BLOCK_PREFIX` are a superset of those that start with `SHUFFLE_CHUNK_PREFIX` - though I am fine with keeping the two checks separate in the interest of clarity.
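
A quick sketch comparing the loop form with a stream form of the same check. The stream version has to keep the "every ID is a shuffle ID" meaning, so it uses `allMatch` on the positive condition (equivalently, `noneMatch` on the negated one). `AllMatchSketch` is a hypothetical name, and the prefixes are passed in as parameters so that no constant values are assumed.

```java
import java.util.Arrays;

final class AllMatchSketch {
  // Loop form, mirroring the structure of the method in the diff.
  static boolean loopForm(String[] ids, String blockPrefix, String chunkPrefix) {
    for (String id : ids) {
      if (!id.startsWith(blockPrefix) && !id.startsWith(chunkPrefix)) {
        return false;
      }
    }
    return true;
  }

  // Stream form with the same meaning: true only if every ID has one of the prefixes.
  static boolean streamForm(String[] ids, String blockPrefix, String chunkPrefix) {
    return Arrays.stream(ids)
        .allMatch(id -> id.startsWith(blockPrefix) || id.startsWith(chunkPrefix));
  }

  public static void main(String[] args) {
    String[] mixed = {"shuffle_0_1_2", "rdd_0_1"};
    System.out.println(loopForm(mixed, "shuffle_", "shuffleChunk_")
        + " " + streamForm(mixed, "shuffle_", "shuffleChunk_"));
  }
}
```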

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);

Review comment:
       Iterate over `primaryIdToBlocksInfo.entrySet` instead ?
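
Roughly what that could look like, as a sketch only: `EntrySetSketch` and `flatten` are hypothetical, and `BlocksInfo` is simplified here to a plain `List<String>` of block IDs. Iterating the entry set of a `LinkedHashMap` keeps insertion order and hands back key and value together, so the per-key `get` goes away.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class EntrySetSketch {
  // Flattens grouped block IDs in iteration order; each entry yields key and value
  // together, so no second get(key) lookup is needed.
  static List<String> flatten(Map<Number, List<String>> grouped) {
    List<String> out = new ArrayList<>();
    for (Map.Entry<Number, List<String>> entry : grouped.entrySet()) {
      out.addAll(entry.getValue());
    }
    return out;
  }

  public static void main(String[] args) {
    Map<Number, List<String>> grouped = new LinkedHashMap<>();
    grouped.put(2L, Arrays.asList("shuffle_0_2_0", "shuffle_0_2_1"));
    grouped.put(1L, Arrays.asList("shuffle_0_1_0"));
    System.out.println(flatten(grouped));
  }
}
```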

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
       What is the expected behavior if closing the channel throws? (The failure might itself be due to `onData` throwing an exception, for example.)
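
One possible shape for this, offered as a sketch and not as what the PR intends: wrap the close in `try`/`finally` so the failure notification still runs when `close()` throws, while the close exception continues to propagate. `CloseThenNotifySketch` and `notifyFailure` are hypothetical, and `java.io.Closeable` stands in for the actual channel type.

```java
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

final class CloseThenNotifySketch {
  // Runs the failure notification even if close() throws; the close exception
  // still propagates to the caller afterwards.
  static void onFailure(Closeable channel, Runnable notifyFailure) throws IOException {
    try {
      channel.close();
    } finally {
      notifyFailure.run();
    }
  }

  public static void main(String[] args) {
    List<String> events = new ArrayList<>();
    try {
      onFailure(() -> { throw new IOException("close failed"); },
          () -> events.add("notified"));
    } catch (IOException e) {
      events.add("propagated: " + e.getMessage());
    }
    System.out.println(events);
  }
}
```

Whether the close exception should propagate or only be logged is exactly the open question above; the sketch picks propagation only to make the ordering visible.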

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));

Review comment:
       Update the comment above, or add a one-line note on what `blockIdParts[4]` can be.
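
For context, a sketch of how the parts line up for a batch block ID, based on the range comment in the diff (`[startReduceId, endReduceId)`). `BatchIdSketch` and `describe` are hypothetical, and the plain `split("_")` only stands in for the PR's `splitBlockId`, whose validation is not reproduced here.

```java
final class BatchIdSketch {
  // Labels the parts of an unmerged shuffle block ID; the 5-part form is the batch form.
  static String describe(String blockId) {
    String[] parts = blockId.split("_");
    if (parts.length != 4 && parts.length != 5) {
      throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
    }
    if (parts.length == 5) {
      // Batch form: parts[3] is startReduceId, parts[4] is endReduceId (exclusive).
      return "shuffleId=" + parts[1] + " mapId=" + parts[2]
          + " reduceIds=[" + parts[3] + ", " + parts[4] + ")";
    }
    return "shuffleId=" + parts[1] + " mapId=" + parts[2] + " reduceId=" + parts[3];
  }

  public static void main(String[] args) {
    System.out.println(describe("shuffle_0_5_2"));
    System.out.println(describe("shuffle_0_5_2_4"));
  }
}
```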

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal to chunkIds.size().

Review comment:
       How strong is this assumption ? Do we see a future evolution where this can break ? Or is it tied to the protocol in nontrivial ways ?
   As an example - `encode` and `decode` do not assume this currently (we could have avoided writing `chunkIdsLen` if they did)
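
To make the trade-off concrete, a sketch of a wire layout that does rely on the invariant. This is a hypothetical layout written with `java.nio.ByteBuffer`, not the PR's `encode`/`decode` and not Spark's `Encoders`; `SharedLengthSketch` is an invented name. A single count `n` is written once and reused as the outer length of the chunk-ID array, so a mismatch has to be rejected before encoding.

```java
import java.nio.ByteBuffer;

final class SharedLengthSketch {
  // Writes: n, then n reduce IDs, then for each reduce ID its chunk IDs (length + values).
  static ByteBuffer encode(int[] reduceIds, int[][] chunkIds) {
    if (reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException("reduceIds and chunkIds must have the same length");
    }
    int size = 4 + 4 * reduceIds.length;
    for (int[] ids : chunkIds) {
      size += 4 + 4 * ids.length;
    }
    ByteBuffer buf = ByteBuffer.allocate(size);
    buf.putInt(reduceIds.length);
    for (int r : reduceIds) {
      buf.putInt(r);
    }
    for (int[] ids : chunkIds) {
      buf.putInt(ids.length);
      for (int c : ids) {
        buf.putInt(c);
      }
    }
    buf.flip();
    return buf;
  }

  // Reads the layout written by encode; the outer chunk array length is taken from n.
  static int[][] decodeChunkIds(ByteBuffer buf) {
    int n = buf.getInt();
    buf.position(buf.position() + 4 * n); // skip the reduce IDs
    int[][] chunkIds = new int[n][];
    for (int i = 0; i < n; i++) {
      chunkIds[i] = new int[buf.getInt()];
      for (int j = 0; j < chunkIds[i].length; j++) {
        chunkIds[i][j] = buf.getInt();
      }
    }
    return chunkIds;
  }

  public static void main(String[] args) {
    int[][] decoded = decodeChunkIds(encode(new int[] {1, 2}, new int[][] {{0, 1}, {2}}));
    System.out.println(decoded.length + " chunk-id arrays decoded");
  }
}
```

Keeping the separate length on the wire, as the current code does, leaves room for the two arrays to diverge later without a protocol change.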

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);

Review comment:
       ```suggestion
       for (String blockId : blocksInfoByPrimaryId.blockIds) {
           this.blockIds[blockIdIndex++] = blockId;
       }
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655047442



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.

Review comment:
       I have made changes to the iterator, the push-based fetch helper, and the iterator suite so that they refer to all push-based merged blocks as "pushMerged" and use just "shuffle chunks".






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655046357



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it
+   * can be modified by both the task thread and the netty thread.
+   */
+  private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote merged block.
+   */
+  def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = {
+    assert(isMergedShuffleBlockAddress(address))
+    address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local merged block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap.get(blockId).getCardinality
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps   per chunk bitmap, where each bitmap contains all the mapIds that are merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding merged local blocks.
+   * @param mergedLocalBlocks set of identified merged local blocks.
+   */
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult(
+              blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block. This can be executed by the task thread as well as the
+   * netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.incrementNumBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          FallbackOnMergedFailureFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type:
+   * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]]
+   * 2) [[ShuffleBlockFetcherIterator.FallbackOnMergedFailureFetchResult]]
+   * 3) [[ShuffleBlockFetcherIterator.MergedMetaFailedFetchResult]]
+   *
+   * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that
+   * failed to fetch.
+   * It makes a call to the map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      blockId match {
+        case shuffleBlockId: ShuffleBlockId =>
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+        case _ =>
+          val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+          val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId)
+          assert(chunkBitmap != null)
+          // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+          // fallback not only for that particular remote shuffle block chunk but also for all the
+          // pending block chunks that belong to the same host. The reason for doing so is that it
+          // is very likely that the subsequent requests for merged block chunks from this host will
+          // fail as well. Since push-based shuffle is best effort and we try not to increase the
+          // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the
+          // fetchRequests queue.
+          if (isMergedBlockAddressRemote(address)) {
+            // Fallback for all the pending fetch requests
+            val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+            if (pendingShuffleChunks.nonEmpty) {
+              pendingShuffleChunks.foreach { pendingBlockId =>
+                logInfo(s"Falling back immediately for merged block $pendingBlockId")
+                val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId)
+                assert(bitmapOfPendingChunk != null)
+                chunkBitmap.or(bitmapOfPendingChunk)
+              }
+              // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed
+              blocksProcessed += pendingShuffleChunks.size
+            }
+          }
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
+      }
+    iterator.fetchFallbackBlocks(fallbackBlocksByAddr)

Review comment:
       done
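
For readers following the fallback path in the hunk above, the chunk handling can be sketched as below. This is a minimal illustration under stated assumptions, not the PR's code: `java.util.BitSet` stands in for `RoaringBitmap`, integer keys stand in for `ShuffleBlockChunkId`, and `FallbackBitmapSketch` and `mergeForFallback` are hypothetical names.

```java
import java.util.Arrays;
import java.util.BitSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: BitSet stands in for RoaringBitmap, Integer for ShuffleBlockChunkId.
class FallbackBitmapSketch {

  // Removes the failed chunk and every pending chunk from the meta map and
  // returns the union of their map-id bitmaps, i.e. the set of original
  // (unmerged) map outputs that would have to be fetched instead.
  static BitSet mergeForFallback(
      Map<Integer, BitSet> chunksMeta, int failedChunk, List<Integer> pendingChunks) {
    BitSet merged = chunksMeta.remove(failedChunk);
    if (merged == null) {
      throw new IllegalStateException("No bitmap for failed chunk " + failedChunk);
    }
    for (int pending : pendingChunks) {
      BitSet pendingBitmap = chunksMeta.remove(pending);
      if (pendingBitmap == null) {
        throw new IllegalStateException("No bitmap for pending chunk " + pending);
      }
      merged.or(pendingBitmap);
    }
    return merged;
  }

  public static void main(String[] args) {
    Map<Integer, BitSet> chunksMeta = new HashMap<>();
    BitSet chunk0 = new BitSet();
    chunk0.set(1);
    chunk0.set(2);
    BitSet chunk1 = new BitSet();
    chunk1.set(3);
    chunksMeta.put(0, chunk0);
    chunksMeta.put(1, chunk1);
    // Chunk 0 failed; chunk 1 is still pending for the same host.
    System.out.println(mergeForFallback(chunksMeta, 0, Arrays.asList(1)));
  }
}
```

The point of taking the union first is that a single map output tracker lookup can then cover the failed chunk and all pending chunks from the same host.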




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-856099709


   > @otterc Given the volume of the PR, does it cleanly separate out into ESS side and client side ?
   > If it does, we can merge the former first and then the latter.
   > 
   > If not, let us keep it as is.
   
   Yes @mridulm. It will cleanly separate out into ESS side and client side. I will work on splitting the PR.




[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r652951329



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.

Review comment:
       I think it is good to differentiate between push-merged blocks and batch blocks.
   But I prefer calling push-based merged blocks `push-merged blocks`, and the batch blocks could be called `batched blocks`. To me this makes it clear which feature is being referred to.
   
   I will clean up the comments and variable names related to push-based shuffle to make this consistent.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640209621



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +360,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $mergedLocalBlockBytes")
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {

Review comment:
       Note to self: I haven't changed this line. I only added the `else if`, so the diff shows it as added by me.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648537705



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)

Review comment:
       For merged blocks that are remote, this check is performed in `collectFetchRequests`. So if we do it before the condition `if (address.host == blockManager.blockManagerId.host)`, it will be done twice for remote merged blocks.
   I think this is also why `checkBlockSizes()` is currently called explicitly for each block type: for remote blocks (the last else block), the size is validated in `collectFetchRequests`.
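
To make the double-validation point concrete, here is a small sketch of the two placements. It is an illustration under assumptions, not the Spark code: `BlockSizeCheckSketch`, `partition`, and the `validations` counter are hypothetical, block infos are reduced to a list of sizes, and the exception type is a stand-in.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of where the size check runs for local vs. remote merged blocks.
class BlockSizeCheckSketch {
  // Counts how many times the validation ran, purely for illustration.
  static int validations = 0;

  static void checkBlockSizes(List<Long> sizes) {
    validations++;
    for (long size : sizes) {
      if (size <= 0) {
        throw new IllegalStateException("Invalid block size " + size);
      }
    }
  }

  // Stand-in for collectFetchRequests, which validates sizes itself.
  static void collectFetchRequests(List<Long> sizes) {
    checkBlockSizes(sizes);
    // ... split the blocks into fetch requests (omitted) ...
  }

  static void partition(boolean isLocal, List<Long> sizes, boolean checkBeforeBranch) {
    if (checkBeforeBranch) {
      checkBlockSizes(sizes);
    }
    if (isLocal) {
      if (!checkBeforeBranch) {
        checkBlockSizes(sizes);
      }
      // ... record the local merged blocks (omitted) ...
    } else {
      collectFetchRequests(sizes);
    }
  }

  public static void main(String[] args) {
    List<Long> sizes = Arrays.asList(10L, 20L);
    partition(false, sizes, true);
    System.out.println("remote, check hoisted: " + validations + " validations");
    validations = 0;
    partition(false, sizes, false);
    System.out.println("remote, check per branch: " + validations + " validations");
  }
}
```

With the check hoisted above the branch, remote blocks are validated once by the caller and once more inside the request-collection step; keeping it inside the local branch avoids the repeat.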






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646728233



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must be equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       If we don't expect it to be null, make it a `Preconditions.checkNotNull` and remove the null check?
   I am not sure whether the null check elsewhere is an artifact of some earlier iteration of the code.
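
A rough sketch of what that suggestion could look like, combined with the two-level index walk from the hunk above. Assumptions are labeled: `ChunkIteratorSketch` and `lookup` are hypothetical, strings stand in for `ManagedBuffer`, and the JDK's `Objects.requireNonNull` is used in place of Guava's `Preconditions.checkNotNull` so the example has no dependencies (both throw `NullPointerException` on a null argument).

```java
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Objects;

// Hypothetical stand-in for ShuffleChunkManagedBufferIterator; Strings replace ManagedBuffers.
class ChunkIteratorSketch implements Iterator<String> {
  private int reduceIdx = 0;
  private int chunkIdx = 0;
  private final int[] reduceIds;
  private final int[][] chunkIds;

  ChunkIteratorSketch(int[] reduceIds, int[][] chunkIds) {
    if (reduceIds.length == 0 || reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException("reduceIds and chunkIds must be non-empty and aligned");
    }
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }

  // Hypothetical lookup; the code under review calls mergeManager.getMergedBlockData(...).
  static String lookup(int reduceId, int chunkId) {
    return "chunk_" + reduceId + "_" + chunkId;
  }

  @Override
  public boolean hasNext() {
    return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
  }

  @Override
  public String next() {
    if (!hasNext()) {
      throw new NoSuchElementException();
    }
    // Fail fast instead of guarding later uses with "block != null ? ... : 0".
    String block = Objects.requireNonNull(
        lookup(reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]),
        "merged chunk must not be null");
    // Advance within the current reduce id, then move on to the next one.
    if (chunkIdx < chunkIds[reduceIdx].length - 1) {
      chunkIdx += 1;
    } else {
      chunkIdx = 0;
      reduceIdx += 1;
    }
    return block;
  }

  public static void main(String[] args) {
    Iterator<String> it = new ChunkIteratorSketch(new int[] {3, 7}, new int[][] {{0, 1}, {0}});
    while (it.hasNext()) {
      System.out.println(it.next());
    }
  }
}
```

Once the value is known to be non-null, the metrics call no longer needs the conditional on `block`.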






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660275905



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,35 +355,56 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+    // blocks. Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+    // in order to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback)
+          .contains(address.executorId)) {

Review comment:
       Done






[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-867341179


   Looks good to me, will wait for @Ngone51 to finish his review. Thanks for the changes and clarifications @otterc !




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649548771



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1394,54 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of either of these:
+   * 1) Remote shuffle block chunk.
+   * 2) Local merged block data.
+   *
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,

Review comment:
       Wouldn't `Retriable` be confusing, since it suggests that this request itself is retried? We could call it `FallbackOnFailureFetchResult`. Let me know what you think.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646731962



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       How often is this invoked? If it is not that common, let us keep it as is (clarity is more important); if it is **very** frequent, let us make it faster with a util method.
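
For reference, a util method of the kind mentioned above could look roughly like the sketch below. `NumberArrays.toLongArray` is a hypothetical name, not an existing Spark or Guava helper, and whether it would be measurably faster than `Longs.toArray` (which already accepts a `Collection<? extends Number>`) is not established here.

```java
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Set;

// Hypothetical helper; the class and method names are illustrative only.
final class NumberArrays {
  /** Copies the values into a long[] in iteration order, using Number.longValue(). */
  static long[] toLongArray(Collection<? extends Number> values) {
    long[] out = new long[values.size()];
    int i = 0;
    for (Number n : values) {
      out[i++] = n.longValue();
    }
    return out;
  }

  public static void main(String[] args) {
    // A LinkedHashSet keeps insertion order, like the keySet of the LinkedHashMap in the diff.
    Set<Number> primaryIds = new LinkedHashSet<>(Arrays.asList(7L, 3L, 11L));
    System.out.println(Arrays.toString(toLongArray(primaryIds))); // prints [7, 3, 11]
  }
}
```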






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656718852



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified merged local blocks and their sizes.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          pushMergedLocalBlocks.takeWhile {

Review comment:
       Yeah, this is a bug; I will fix it. An iterator usually fetches just a single merged block for a particular reduce partition, which is why none of our UTs caught it. I will add a UT as well.
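
The thread does not spell out the bug. One plausible reading, stated here as an assumption, is that `takeWhile` stops at the first element whose body evaluates to false, so later blocks in the set would never be attempted, and with a single block the difference is invisible. The Java sketch below illustrates that reading with a hand-written stand-in for Scala's `takeWhile`; all names are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

// Hypothetical illustration; not the code from the PR.
final class TakeWhileSketch {
  /** Stand-in for Scala's takeWhile: stops at the first element that fails the predicate. */
  static <T> List<T> takeWhile(Iterable<T> items, Predicate<T> predicate) {
    List<T> taken = new ArrayList<>();
    for (T item : items) {
      if (!predicate.test(item)) {
        break;
      }
      taken.add(item);
    }
    return taken;
  }

  /** Returns the blocks for which a fetch was attempted when driven by takeWhile. */
  static List<String> attemptedWithTakeWhile(List<String> blocks, Predicate<String> fetch) {
    List<String> attempted = new ArrayList<>();
    takeWhile(blocks, b -> {
      attempted.add(b);
      return fetch.test(b);
    });
    return attempted;
  }

  /** Returns the blocks for which a fetch was attempted when driven by a plain loop. */
  static List<String> attemptedWithForEach(List<String> blocks, Predicate<String> fetch) {
    List<String> attempted = new ArrayList<>();
    for (String b : blocks) {
      attempted.add(b);
      fetch.test(b); // the result is ignored; every block is attempted
    }
    return attempted;
  }

  public static void main(String[] args) {
    List<String> blocks = Arrays.asList("blockA", "blockB", "blockC");
    Predicate<String> alwaysFails = b -> false;
    System.out.println(attemptedWithTakeWhile(blocks, alwaysFails)); // prints [blockA]
    System.out.println(attemptedWithForEach(blocks, alwaysFails));   // prints [blockA, blockB, blockC]
  }
}
```

With a one-element list both variants attempt exactly one block, which would be consistent with the existing UTs not catching the problem.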






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-849206077


   I have rebased the changes against the latest master. It only depends on https://github.com/apache/spark/pull/32140, but it interfaces with that PR in just one place, so it can still be reviewed in parallel.
   Given the size of this change and the feature freeze in early July, I would really appreciate it if folks could help with the review.
   @mridulm @Victsm @Ngone51 @tgravescs @attilapiros 




[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660774462



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +878,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the local push-merged meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the local push-merged directories from the ESS. In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) =>
+            // Fetch local push-merged shuffle block data as multiple shuffle chunks
+            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            try {
+              val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+                localDirs)
+              // Since the request for local block meta completed successfully, numBlocksToFetch
+              // is decremented.
+              numBlocksToFetch -= 1
+              // Update total number of blocks to fetch, reflecting the multiple local shuffle
+              // chunks.
+              numBlocksToFetch += bufs.size
+              bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+                  isNetworkReqDone = false))
+              }
+            } catch {
+              case e: Exception =>
+                // If we see an exception with reading local push-merged data, we fallback to

Review comment:
       Changed the comment, so I am resolving the conversation.
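
The bookkeeping in the `FallbackOnPushMergedFailureResult` case above can be modelled in isolation as follows. This is a sketch in Java with hypothetical names (`InFlightAccounting`, `onRequestSent`, `onFallback`); how the real iterator increments the counters when a request is sent is an assumption made only so that the decrements have something to undo.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical model of the in-flight counters; not Spark code.
final class InFlightAccounting {
  long bytesInFlight = 0L;
  int reqsInFlight = 0;
  final Map<String, Integer> blocksInFlightPerAddress = new HashMap<>();

  /** Assumed send path: one request, several blocks, each adding its size. */
  void onRequestSent(String address, long... blockSizes) {
    reqsInFlight += 1;
    for (long size : blockSizes) {
      bytesInFlight += size;
      blocksInFlightPerAddress.merge(address, 1, Integer::sum);
    }
  }

  /**
   * Mirrors the case in the diff: block and byte counters are released only for
   * remote push-merged addresses, and the request counter only when the last
   * network response of the request has arrived.
   */
  void onFallback(String address, long size, boolean isRemote, boolean isNetworkReqDone) {
    if (isRemote) {
      blocksInFlightPerAddress.merge(address, -1, Integer::sum);
      bytesInFlight -= size;
    }
    if (isNetworkReqDone) {
      reqsInFlight -= 1;
    }
  }
}
```

Local failures (cases 2 and 3 in the diff comment) never went over the network for block data, which is why the sketch, like the diff, leaves the byte and block counters untouched for them.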






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645709830



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));

Review comment:
       `batchFetchEnabled` will only be true for regular shuffle blocks, not for merged shuffle block chunks. I will add this note here and also add a comment for `blockIdParts[4]`.
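
To make the role of `blockIdParts[3]` and `blockIdParts[4]` concrete, here is a small sketch of the ID layouts as they can be read off the diff. It is illustrative only: `BlockIdSketch` and its methods are hypothetical, the value `"shuffleChunk_"` for `SHUFFLE_CHUNK_PREFIX` is an assumption (the diff only shows the `"shuffle_"` literal), and unlike the PR it inspects each ID on its own rather than deciding from the first block.

```java
// Hypothetical helper for illustration; not part of OneForOneBlockFetcher.
final class BlockIdSketch {
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";      // literal from the removed diff line
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_"; // assumed value

  /** Batch fetch IDs have five parts: shuffle_<shuffleId>_<mapId>_<startReduceId>_<endReduceId>. */
  static boolean isBatchFetch(String blockId) {
    return blockId.startsWith(SHUFFLE_BLOCK_PREFIX) && blockId.split("_").length == 5;
  }

  /**
   * Returns the secondary IDs: [chunkId] for a chunk, [reduceId] for a regular
   * block, or [startReduceId, endReduceId] for a batch-fetch block.
   */
  static int[] secondaryIds(String blockId) {
    String[] parts = blockId.split("_");
    if (blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
      // Chunks are never batch fetched, so there is no fifth part.
      if (parts.length != 4) {
        throw new IllegalArgumentException("Unexpected chunk id: " + blockId);
      }
      return new int[] {Integer.parseInt(parts[3])};
    }
    if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) || (parts.length != 4 && parts.length != 5)) {
      throw new IllegalArgumentException("Unexpected shuffle block id: " + blockId);
    }
    if (parts.length == 5) {
      return new int[] {Integer.parseInt(parts[3]), Integer.parseInt(parts[4])};
    }
    return new int[] {Integer.parseInt(parts[3])};
  }
}
```

For example, `secondaryIds("shuffle_0_5_1_4")` yields the range bounds `[1, 4]`, while `secondaryIds("shuffleChunk_0_1_2")` yields the single chunk ID `[2]`.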






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656718852



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified push-merged local blocks.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local push-merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          pushMergedLocalBlocks.takeWhile {

Review comment:
       Yeah, this is a bug. Will fix it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645901855



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
       Reverted this change.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655044846



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it
+   * can be modified by both the task thread and the netty thread.
+   */
+  private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote merged block.
+   */
+  def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = {
+    assert(isMergedShuffleBlockAddress(address))
+    address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local merged block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def getNumberOfBlocksInChunk(blockId: ShuffleBlockChunkId): Int = {
+    chunksMetaMap.get(blockId).getCardinality
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps   per chunk bitmap, where each bitmap contains all the mapIds that are merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding merged local blocks.
+   * @param mergedLocalBlocks set of identified merged local blocks.
+   */
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult(
+              blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block. This can be executed by the task thread as well as by the
+   * netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.incrementNumBlocksToFetch(bufs.size - 1)

Review comment:
       Resolving this. All updates to `numBlocksToFetch` and `chunksMetaMap` are now done by the task thread, so I have also reverted the change that made `chunksMetaMap` a concurrent map.
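
       A minimal sketch of the thread-confinement pattern described above, not the PR's actual code: callback threads only enqueue results, and a single consumer thread applies them to a plain mutable map. All names here (`Result`, `MetaFetched`, `ChunkFetched`, `ThreadConfinementSketch`) are hypothetical stand-ins rather than Spark classes.

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable

// Hypothetical stand-ins for the iterator's result types; these are not Spark's classes.
sealed trait Result
final case class MetaFetched(chunkId: String, numBlocks: Int) extends Result
final case class ChunkFetched(chunkId: String) extends Result

object ThreadConfinementSketch {
  // Thread-safe hand-off point between callback threads and the consuming thread.
  private val results = new LinkedBlockingQueue[Result]()

  // Touched only by the consuming (task) thread, so a plain mutable map is sufficient.
  private val chunksMeta = mutable.HashMap.empty[String, Int]

  // Called from a callback thread: it only enqueues and never mutates the map.
  def onCallback(r: Result): Unit = results.put(r)

  // Called from the task thread: the single place where the map is updated.
  def processNext(): Unit = results.take() match {
    case MetaFetched(id, n) => chunksMeta(id) = n
    case ChunkFetched(id)   => chunksMeta.remove(id)
  }

  def pendingChunks: Int = chunksMeta.size
}
```

       Because the queue is the only shared structure, the map does not need to be a `ConcurrentHashMap` in this arrangement.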






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640210756



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +824,66 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                numBlocksProcessed += pushBasedFetchHelper
+                  .initiateFallbackBlockFetchForMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>

Review comment:
       Note to self: Most of this is unchanged. I have only added the conditions for shuffle chunks.






[GitHub] [spark] otterc commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-864697969


   This PR is not a WIP any more. @mridulm @Ngone51 I have addressed all the review comments so far. PTAL.




[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660221758



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,35 +355,56 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+    // blocks. Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+    // in order to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback)
+          .contains(address.executorId)) {

Review comment:
       nit: `mutable.HashSet` -> `Set`.
   Also, could you pull this `Set` out of the loop? You can then remove the `fallback` variable and populate the `Set` directly.
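
   A rough sketch of the suggestion, using hypothetical ids in place of `blockManager.blockManagerId.executorId` and `FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId`. The object and method names are illustrative only and do not appear in the PR.

```scala
object LocalExecutorCheckSketch {
  // Hypothetical values standing in for the local executor id and the fallback executor id.
  val localExecutorId = "exec-1"
  val fallbackExecutorId = "fallback"

  // Built once, outside any loop; an immutable Set gives the same membership test
  // without allocating a new collection on every iteration.
  private val localExecIds = Set(localExecutorId, fallbackExecutorId)

  // Counts how many of the given executor ids would take the "local" branch.
  def countLocal(executorIds: Seq[String]): Int =
    executorIds.count(localExecIds.contains)
}
```

   For example, `LocalExecutorCheckSketch.countLocal(Seq("exec-1", "exec-2", "fallback"))` treats the first and last ids as local.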






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640209621



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +360,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $mergedLocalBlockBytes")
+        } else {
+          remoteBlockBytes += blockInfos.map(_._2).sum
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {

Review comment:
       Note to self: I haven't changed this line. I only added the `else if`, so the diff shows it as added by me.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1391,297 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
 }
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+    shuffleId: Int,
+    reduceId: Int,
+    blockSize: Long,
+    numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    })
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged local blocks dirs/blocks..
+   */
+  private def fetchMergedLocalBlocks(
+    hostLocalDirManager: HostLocalDirManager,
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(
+              IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block generated.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+    blockId: BlockId,
+    localDirs: Array[String],
+    blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.foundMoreBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        if (isNotExecutorOrMergedLocal(address)) {
+          // Fallback for all the pending fetch requests
+          val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap =
+                chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)
+            }
+            // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed
+            blocksProcessed += pendingShuffleChunks.size
+          }
+        }
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
+      }
+    iterator.fetchFallbackBlocks(fallbackBlocksByAddr)
+    blocksProcessed
+  }
+}

Review comment:
       Add a newline at the end of the file.

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1391,297 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
 }
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+    shuffleId: Int,
+    reduceId: Int,
+    blockSize: Long,
+    numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    })
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged local blocks dirs/blocks..

Review comment:
       Nit: `..`

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1391,297 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
 }
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /** A map for storing merged block shuffle chunk bitmap */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is not of executor local or merged local block. false otherwise.
+   */
+  def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = {
+    (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) ||
+      (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId)
+  }
+
+  /**
+   * Returns true if the address if of merged local block. false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  def createChunkBlockInfosFromMetaResponse(
+    shuffleId: Int,
+    reduceId: Int,
+    blockSize: Long,
+    numChunks: Int,
+    bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+      new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToRequest
+  }
+
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId,
+            sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            iterator.addToResultsQueue(
+              MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    })
+  }
+
+  // Fetch all outstanding merged local blocks
+  def fetchAllMergedLocalBlocks(
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged local blocks dirs/blocks..
+   */
+  private def fetchMergedLocalBlocks(
+    hostLocalDirManager: HostLocalDirManager,
+    mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)

Review comment:
       Nit: should import `BlockManagerId.SHUFFLE_MERGER_IDENTIFIER`
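
A minimal sketch of what this nit asks for, assuming a stand-in object rather than the real `BlockManagerId` companion: `BlockManagerIdLike`, `ImportNit`, and the constant's value are hypothetical and exist only so the example compiles on its own.

```scala
// Hypothetical stand-in for the companion object; the value is made up.
object BlockManagerIdLike {
  val SHUFFLE_MERGER_IDENTIFIER = "merger-placeholder"
}

object ImportNit {
  import BlockManagerIdLike.SHUFFLE_MERGER_IDENTIFIER

  // With the import in scope, call sites can drop the object prefix.
  def cachedMergerDirs(cache: Map[String, Array[String]]): Option[Array[String]] =
    cache.get(SHUFFLE_MERGER_IDENTIFIER)
}
```

Importing the constant once shortens every lookup in the helper that currently spells out the full `BlockManagerId.SHUFFLE_MERGER_IDENTIFIER` path.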

##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -712,38 +824,66 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-            // the corruption is later, we'll still detect the corruption later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                numBlocksProcessed += pushBasedFetchHelper
+                  .initiateFallbackBlockFetchForMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to get either.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+              // the corruption is later, we'll still detect the corruption later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>

Review comment:
       Note to self: Most of this is as before. I have only added the conditions for shuffle chunks.
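
The added condition can be sketched on its own, under stated assumptions: `Block`, `open`, and `onFallback` are hypothetical names, and the real iterator does considerably more bookkeeping (for example, updating `numBlocksProcessed`). The sketch only shows the split between absorbing a failure for a shuffle chunk and rethrowing it for any other block.

```scala
import java.io.IOException

object ChunkFallbackSketch {
  // Hypothetical stand-in for a block id; only the chunk flag matters here.
  final case class Block(name: String, isShuffleChunk: Boolean)

  // Returns Some(contents) on success, None when a chunk failure was absorbed
  // and the caller should loop again; rethrows for ordinary blocks.
  def open(block: Block, read: () => String, onFallback: Block => Unit): Option[String] =
    try Some(read())
    catch {
      case e: IOException =>
        if (block.isShuffleChunk) {
          onFallback(block)
          None
        } else {
          throw e
        }
    }
}
```

Returning `None` here plays the role of setting `result = null` in the diff: the caller sees no usable stream and goes around the loop again.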

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -87,6 +97,32 @@ case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) exte
   override def name: String = "shufflePush_" + shuffleId + "_" + mapIndex + "_" + reduceId
 }
 
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleMergedBlockId(appId: String, shuffleId: Int, reduceId: Int) extends BlockId {

Review comment:
       These are also part of SPARK-33350.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] Ngone51 commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870662778


   I left some minor comments. I think we're ready to merge after addressing these comments.


For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870085695


   @otterc Can you also please check whether #33109 is relevant to push-based shuffle? Thanks.




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645870074



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
       Thanks for bringing this up @mridulm. This is something I missed as the push-based shuffle code evolved. We don't need this change and I will revert it.
   We do want to fail all the remaining chunkIds as well so that they get retried by `RetryingBlockFetcher`.
   Even with this code, the remaining chunkIds are going to be retried by `RetryingBlockFetcher`, because it retries not only the fetch of the failed block but also the fetch of all the outstanding blocks. In addition, `RetryingBlockFetchListener.initiateRetry` changes the `currentListener`, so any invocation of `onBlockFetchFailure()` on an old instance of `RetryingBlockFetchListener` does nothing, because the listener checks whether it is still the current instance.
   
   In short, this change is unnecessary.
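
The stale-listener behaviour described above can be sketched roughly as follows. This is a simplified, hypothetical model and not the actual `RetryingBlockFetcher` code: the class `StaleListenerSketch`, its `log` list, and the chunk names are invented for illustration. It only shows the idea that a retry covers all outstanding blocks and replaces the current listener, so failures reported to a superseded listener are dropped.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: callbacks that arrive on a superseded listener are ignored. */
final class StaleListenerSketch {
  interface FetchListener {
    void onBlockFetchFailure(String blockId, Throwable cause);
  }

  private final List<String> outstanding = new ArrayList<>();
  private final List<String> log = new ArrayList<>();
  private FetchListener currentListener;
  private int retryCount = 0;

  StaleListenerSketch(List<String> blockIds) {
    outstanding.addAll(blockIds);
    currentListener = new RetryingListener();
  }

  FetchListener currentListener() { return currentListener; }
  List<String> log() { return log; }
  int retryCount() { return retryCount; }

  private final class RetryingListener implements FetchListener {
    @Override
    public void onBlockFetchFailure(String blockId, Throwable cause) {
      synchronized (StaleListenerSketch.this) {
        if (this != currentListener) {
          // A retry has already replaced this listener, so do nothing.
          log.add("ignored stale failure for " + blockId);
          return;
        }
        // The retry covers every outstanding block, not only the one that failed.
        retryCount++;
        currentListener = new RetryingListener();
        log.add("retry " + retryCount + " for " + outstanding);
      }
    }
  }

  public static void main(String[] args) {
    StaleListenerSketch fetcher =
        new StaleListenerSketch(Arrays.asList("chunk_0", "chunk_1", "chunk_2"));
    FetchListener first = fetcher.currentListener();
    // Two failures reported to the same (old) listener: only the first triggers a retry.
    first.onBlockFetchFailure("chunk_0", new RuntimeException("fetch failed"));
    first.onBlockFetchFailure("chunk_1", new RuntimeException("fetch failed"));
    fetcher.log().forEach(System.out::println);
  }
}
```

Under this model a separate per-chunk failure callback adds nothing: the first failure already schedules a retry of everything outstanding, and later callbacks on the old listener are no-ops.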






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645901032



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));

Review comment:
       done
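
For readers following the diff above, the grouping step can be sketched on its own. This is a minimal illustration under stated assumptions, not the code in `OneForOneBlockFetcher`: the class `BlockIdGroupingSketch` is hypothetical, the value `"shuffleChunk_"` for `SHUFFLE_CHUNK_PREFIX` is assumed (only `"shuffle_"` appears in the diff), and a `Long` key is used for both cases instead of the `Number` key in the patch.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of grouping block IDs by their primary ID (mapId or reduceId). */
final class BlockIdGroupingSketch {
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  // Assumed value; the diff only shows the constant's name.
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  /** True only if every ID is an unmerged shuffle block ID or a merged shuffle chunk ID. */
  static boolean areShuffleBlocksOrChunks(String[] blockIds) {
    for (String blockId : blockIds) {
      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX)
          && !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
        return false;
      }
    }
    return true;
  }

  /**
   * Groups the fourth ID component by the third one, preserving insertion order.
   * For shuffle_<shuffleId>_<mapId>_<reduceId> the key is mapId; for
   * shuffleChunk_<shuffleId>_<reduceId>_<chunkId> the key is reduceId.
   */
  static Map<Long, List<Integer>> groupByPrimaryId(String[] blockIds) {
    if (blockIds.length == 0) {
      throw new IllegalArgumentException("Zero-sized blockIds array");
    }
    int shuffleId = Integer.parseInt(blockIds[0].split("_")[1]);
    Map<Long, List<Integer>> grouped = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String[] parts = blockId.split("_");
      if (Integer.parseInt(parts[1]) != shuffleId) {
        throw new IllegalArgumentException(
            "Expected shuffleId=" + shuffleId + ", got:" + blockId);
      }
      long primaryId = Long.parseLong(parts[2]);
      grouped.computeIfAbsent(primaryId, k -> new ArrayList<>())
          .add(Integer.parseInt(parts[3]));
    }
    return grouped;
  }

  public static void main(String[] args) {
    String[] chunks = {"shuffleChunk_0_1_0", "shuffleChunk_0_1_1", "shuffleChunk_0_2_0"};
    System.out.println(areShuffleBlocksOrChunks(chunks));
    System.out.println(groupByPrimaryId(chunks));
  }
}
```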






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655045474



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1047,81 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def incrementNumBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure for a shuffle merged block/chunk.
+   * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+   * fallback.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged")
+    // fetch all the fallback blocks that are local.
+    fetchLocalBlocks(fallbackLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(fallbackMergedLocalBlocks.isEmpty,
+      "There should be zero merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host as the block chunk that had
+   * a fetch failure.
+   * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+   * fallback.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleBlockChunk(block: BlockId): Boolean = {

Review comment:
       done






[GitHub] [spark] Ngone51 commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648821768



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fall back to fetching the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       It's not about the API; it's an engineering concern. Personally, I think it's more natural to call the push-based functions directly in `ShuffleBlockFetcherIterator` rather than going through a separate helper instance.






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-853476109


   Gentle ping to help review this PR @tgravescs @attilapiros @Ngone51 @mridulm @Victsm @zhouyejoe 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649419829



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],

Review comment:
       This is already addressed so resolving it.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648847338



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       Actually, this is the existing code which I haven't modified. The while loop inside iterator.next() is as below, so `fetchUpToMaxBytes` is always called after a response is matched and processed.
   ```
   while (result == null) {
     val startFetchWait = System.nanoTime()
     result = results.take()
     val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait)
     shuffleMetrics.incFetchWaitTime(fetchWaitTime)

     result match {...}

     // Send fetch requests up to maxBytesInFlight
     fetchUpToMaxBytes()
   }
   ```
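
The control flow of that loop can be mimicked in a small standalone sketch. Everything here is hypothetical (the class `FetchLoopSketch`, the string-encoded results, and the immediately-completing requests); it only illustrates that setting the result to null forces another iteration, and that the send step runs after every processed response, which is how requests queued while handling a meta response get sent.

```java
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/** Hypothetical sketch of a "take, match, then always send more requests" loop. */
final class FetchLoopSketch {
  private final BlockingQueue<String> results = new LinkedBlockingQueue<>();
  private final Queue<String> pendingRequests = new ArrayDeque<>();
  private int sendCalls = 0;

  void addResult(String result) { results.add(result); }
  int sendCalls() { return sendCalls; }

  /** Returns the next data result, transparently handling intermediate meta results. */
  String next() {
    String result = null;
    while (result == null) {
      String raw;
      try {
        raw = results.take();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException("interrupted while waiting for a result", e);
      }
      if (raw.startsWith("meta:")) {
        // A meta response yields nothing for the caller; it only queues follow-up requests.
        pendingRequests.add("chunks-for-" + raw.substring("meta:".length()));
        result = null; // force another iteration
      } else {
        result = raw;
      }
      // Runs after every processed response, including the ones that produced no result.
      fetchUpToMaxBytes();
    }
    return result;
  }

  private void fetchUpToMaxBytes() {
    sendCalls++;
    String request;
    while ((request = pendingRequests.poll()) != null) {
      // For the sketch, pretend each request completes immediately.
      results.add("data:" + request);
    }
  }

  public static void main(String[] args) {
    FetchLoopSketch loop = new FetchLoopSketch();
    loop.addResult("meta:reduce0");
    System.out.println(loop.next());
  }
}
```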
    






[GitHub] [spark] Ngone51 commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870235923


   Sorry for the delay. I'll do a review today. BTW, are there any other necessary Magnet PRs that have to be merged for the 3.2 release?




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655046020



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -871,6 +1047,81 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def incrementNumBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure for a shuffle merged block/chunk.
+   * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+   * fallback.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged")
+    // fetch all the fallback blocks that are local.
+    fetchLocalBlocks(fallbackLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(fallbackMergedLocalBlocks.isEmpty,
+      "There should be zero merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host as the block chunk that had
+   * a fetch failure.
+   * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates
+   * fallback.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleBlockChunk(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll(req => {
+        val firstBlock = req.blocks.head
+        firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleBlockChunk(firstBlock.blockId)
+      })
+      fetchRequestsToRemove.foreach(req => {
+        removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])
+      })

Review comment:
       I made this change with others.
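
The queue filtering performed by `removePendingChunks` in the hunk above can be sketched independently of Spark. This is a simplified model with hypothetical types (`PendingChunkRemovalSketch`, `ChunkId`, `Request`, and plain string addresses); it assumes every queued request holds chunk IDs, whereas the patch also checks `isShuffleChunk` on the first block.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/** Hypothetical sketch: drop queued requests for the same merged block on the same address. */
final class PendingChunkRemovalSketch {
  static final class ChunkId {
    final int shuffleId;
    final int reduceId;
    final int chunkId;

    ChunkId(int shuffleId, int reduceId, int chunkId) {
      this.shuffleId = shuffleId;
      this.reduceId = reduceId;
      this.chunkId = chunkId;
    }

    String name() {
      return "shuffleChunk_" + shuffleId + "_" + reduceId + "_" + chunkId;
    }
  }

  static final class Request {
    final String address;
    final List<ChunkId> chunks;

    Request(String address, ChunkId... chunks) {
      this.address = address;
      this.chunks = Arrays.asList(chunks);
    }
  }

  /** Removes matching requests from the queue and returns the names of their chunks. */
  static Set<String> removePendingChunks(Deque<Request> queue, ChunkId failed, String address) {
    Set<String> removed = new HashSet<>();
    Iterator<Request> it = queue.iterator();
    while (it.hasNext()) {
      Request req = it.next();
      // As in the patch, only the first block of a request is inspected.
      ChunkId first = req.chunks.get(0);
      boolean sameMergedBlock =
          first.shuffleId == failed.shuffleId && first.reduceId == failed.reduceId;
      if (req.address.equals(address) && sameMergedBlock) {
        for (ChunkId chunk : req.chunks) {
          removed.add(chunk.name());
        }
        it.remove();
      }
    }
    return removed;
  }

  public static void main(String[] args) {
    Deque<Request> queue = new ArrayDeque<>();
    queue.add(new Request("hostA", new ChunkId(0, 1, 1), new ChunkId(0, 1, 2)));
    queue.add(new Request("hostB", new ChunkId(0, 1, 3)));
    queue.add(new Request("hostA", new ChunkId(0, 2, 0)));
    Set<String> removed = removePendingChunks(queue, new ChunkId(0, 1, 0), "hostA");
    System.out.println(removed + ", remaining requests: " + queue.size());
  }
}
```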






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r655045176



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it
+   * can be modified by both the task thread and the netty thread.
+   */
+  private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote merged block.
+   */
+  def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = {
+    assert(isMergedShuffleBlockAddress(address))
+    address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap.get(blockId).getCardinality
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps   per chunk bitmap, where each bitmap contains all the mapIds that are merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding merged local blocks.
+   * @param mergedLocalBlocks set of identified merged local blocks.
+   */
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult(
+              blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block. This can be executed by the task thread as well as the
+   * netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.incrementNumBlocksToFetch(bufs.size - 1)
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId,
+          shuffleBlockId.reduceId, chunkId)
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+        iterator.addToResultsQueue(
+          SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf,
+            isNetworkReqDone = false))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception with reading a local merged block, we fallback to
+        // fetch the original unmerged blocks. We do not report block fetch failure
+        // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          FallbackOnMergedFailureFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type:
+   * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]]
+   * 2) [[ShuffleBlockFetcherIterator.FallbackOnMergedFailureFetchResult]]
+   * 3) [[ShuffleBlockFetcherIterator.MergedMetaFailedFetchResult]]
+   *
+   * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that
+   * failed to fetch.
+   * It makes a call to the map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   *    See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      blockId match {
+        case shuffleBlockId: ShuffleBlockId =>
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+        case _ =>
+          val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+          val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId)
+          assert(chunkBitmap != null)
+          // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+          // fallback not only for that particular remote shuffle block chunk but also for all the
+          // pending block chunks that belong to the same host. The reason for doing so is that it
+          // is very likely that the subsequent requests for merged block chunks from this host will
+          // fail as well. Since, push-based shuffle is best effort and we try not to increase the
+          // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the
+          // fetchRequests queue.
+          if (isMergedBlockAddressRemote(address)) {
+            // Fallback for all the pending fetch requests
+            val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+            if (pendingShuffleChunks.nonEmpty) {
+              pendingShuffleChunks.foreach { pendingBlockId =>
+                logInfo(s"Falling back immediately for merged block $pendingBlockId")
+                val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId)
+                assert(bitmapOfPendingChunk != null)
+                chunkBitmap.or(bitmapOfPendingChunk)
+              }
+              // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed
+              blocksProcessed += pendingShuffleChunks.size
+            }
+          }
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
+      }
+    iterator.fetchFallbackBlocks(fallbackBlocksByAddr)

Review comment:
       done
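
The fallback path in the hunk above ORs the bitmaps of all pending chunks from the same host into the bitmap of the failed chunk, then counts the failed chunk plus every pending chunk as processed. Below is a minimal sketch of that bookkeeping. It assumes `java.util.BitSet` as a stand-in for `RoaringBitmap`, and the class and method names are hypothetical. Unlike the PR code, which mutates `chunkBitmap` in place, the sketch works on a copy.

```java
import java.util.BitSet;
import java.util.List;

final class FallbackBitmapSketch {
  /**
   * Returns the union of the failed chunk's mapIds and the mapIds of all pending chunks.
   * BitSet is only a stand-in for RoaringBitmap here.
   */
  static BitSet mergeForFallback(BitSet failedChunk, List<BitSet> pendingChunks) {
    BitSet merged = (BitSet) failedChunk.clone();
    for (BitSet pending : pendingChunks) {
      merged.or(pending);
    }
    return merged;
  }

  /** One for the failed chunk itself, plus one per pending chunk that is skipped. */
  static int blocksProcessed(int pendingChunkCount) {
    return 1 + pendingChunkCount;
  }
}
```

The merged bitmap is what a single lookup of the original unmerged blocks would be keyed on, so one call covers the failed chunk and all pending chunks.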




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645895089



##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -124,11 +134,12 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r

Review comment:
       This is to support the merged map id, which uses `-1` as the mapId. Is the suggestion to replace `-?` with `\\d+`? That doesn't work, or maybe I misunderstood.
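
To illustrate why the optional minus sign matters, here is a small sketch that runs both patterns from the hunk against a block name whose mapId is `-1`. The two regular expressions are copied from the diff; the class name and the `parseMapId` helper are hypothetical and not part of Spark.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

final class ShuffleBlockIdRegexSketch {
  // Pattern before the change: the mapId group accepts digits only.
  static final Pattern OLD_SHUFFLE =
      Pattern.compile("shuffle_([0-9]+)_([0-9]+)_([0-9]+)");
  // Pattern after the change: the mapId group accepts an optional leading minus sign.
  static final Pattern NEW_SHUFFLE =
      Pattern.compile("shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)");

  /** Returns the parsed mapId, or null when the whole name does not match the pattern. */
  static Long parseMapId(Pattern pattern, String blockName) {
    Matcher m = pattern.matcher(blockName);
    return m.matches() ? Long.valueOf(m.group(2)) : null;
  }

  public static void main(String[] args) {
    System.out.println(parseMapId(OLD_SHUFFLE, "shuffle_0_-1_3"));
    System.out.println(parseMapId(NEW_SHUFFLE, "shuffle_0_-1_3"));
  }
}
```

Because a digit-only group can never consume the `-`, the old pattern rejects the whole name rather than parsing a wrong mapId.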






[GitHub] [spark] Ngone51 commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-866974510


   @otterc could you update the PR description? Looks like it's outdated.




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649514612



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
     var curBlocks = Seq.empty[FetchBlockInfo]
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
-      assertPositiveBlockSize(blockId, size)
       curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks, address, isLast = false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, areMergedBlocks) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _) => (false, false)
+          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
       curBlocks = createFetchRequests(curBlocks, address, isLast = true,
-        collectedRemoteRequests)
+        collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+        forMergedMetas = areMergedBlocks)
       curRequestSize = curBlocks.map(_.size).sum

Review comment:
       We do want the sum of the sizes of all the blocks in `curBlocks` so I think the `sum` is needed.
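
A rough sketch of why the sum is needed: when a trailing group of blocks is handed back to the caller instead of being sent, more than one block can be left over, so the running request size has to be recomputed over all of them. The code below is a simplified model and not Spark's `createFetchRequests`; blocks are represented only by their sizes, and the class, method, and parameter names are assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.List;

final class FetchRequestBatchingSketch {
  /**
   * Hypothetical stand-in for createFetchRequests. Block sizes are grouped into requests of
   * at most maxBlocks entries. Full groups (or every group when isLast is true) are appended
   * to collected; a trailing partial group is returned to the caller.
   */
  static List<Long> createFetchRequests(
      List<Long> curBlocks, int maxBlocks, boolean isLast, List<List<Long>> collected) {
    List<Long> leftover = new ArrayList<>();
    for (int start = 0; start < curBlocks.size(); start += maxBlocks) {
      int end = Math.min(start + maxBlocks, curBlocks.size());
      List<Long> group = new ArrayList<>(curBlocks.subList(start, end));
      if (isLast || group.size() == maxBlocks) {
        collected.add(group);
      } else {
        leftover.addAll(group);
      }
    }
    return leftover;
  }

  /** Sum of the sizes of all blocks, mirroring curBlocks.map(_.size).sum. */
  static long totalSize(List<Long> blocks) {
    long total = 0L;
    for (long size : blocks) {
      total += size;
    }
    return total;
  }
}
```

With five blocks and a limit of three blocks per request, two blocks are handed back, and the next request has to start from their combined size rather than from zero or from a single block's size.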






[GitHub] [spark] AmplabJenkins commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
AmplabJenkins commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870942257


   
   Refer to this link for build results (access rights to CI server needed): 
   https://amplab.cs.berkeley.edu/jenkins//job/SparkPullRequestBuilder/140388/
   




[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-856287691


   Created [SPARK-35671](https://issues.apache.org/jira/browse/SPARK-35671). The PR is https://github.com/apache/spark/pull/32811.
   
   This PR will depend on the changes in SPARK-35671.
   cc. @mridulm @Ngone51 @Victsm @tgravescs @attilapiros 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646965372



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
+        // For regular blocks this is mapId. For chunks this is reduceId.
         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks this is reduceId. For chunks this is chunkId.
         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
       I have made this change in this PR: https://github.com/apache/spark/pull/32811
   So I am resolving the conversation here.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r657585363



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,

Review comment:
       If we are asserting on `bitmaps.size() == numChunks`, why are we passing around `numChunks`?
   I am fine with keeping `numChunks` as part of the protocol for forward compatibility, but can the rest of the code, as it stands today, rely on `bitmaps.size()` instead? (With a ser/deser check in `PushMergedRemoteMetaFetchResult` and `PushMergedLocalMetaFetchResult` to enforce this requirement for now.)
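
As a sketch of what this suggestion could look like, the method below derives the chunk count from the bitmap array instead of taking it as a separate argument. This is only an illustration under stated assumptions: `java.util.BitSet` stands in for `RoaringBitmap`, and `ChunkInfoSketch`, `ChunkInfo`, and `createChunkInfos` are hypothetical names, not the PR's API.

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.List;

final class ChunkInfoSketch {
  /** Hypothetical value type: chunk index plus the approximate chunk size in bytes. */
  static final class ChunkInfo {
    final int chunkId;
    final long approxSize;

    ChunkInfo(int chunkId, long approxSize) {
      this.chunkId = chunkId;
      this.approxSize = approxSize;
    }
  }

  /** Builds one ChunkInfo per bitmap; the number of chunks is the length of the array. */
  static List<ChunkInfo> createChunkInfos(long blockSize, BitSet[] bitmaps) {
    if (bitmaps.length == 0) {
      throw new IllegalArgumentException("a merged block needs at least one chunk bitmap");
    }
    int numChunks = bitmaps.length;
    // Same approximation as in the hunk: every chunk gets an equal share of the block size.
    long approxChunkSize = blockSize / numChunks;
    List<ChunkInfo> infos = new ArrayList<>(numChunks);
    for (int i = 0; i < numChunks; i++) {
      infos.add(new ChunkInfo(i, approxChunkSize));
    }
    return infos;
  }
}
```

Dropping the extra parameter removes the possibility of the two values disagreeing, and the explicit empty-array check also guards the division.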






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648682648



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       I have also added a UT to verify this: `iterator has just 1 merged block and fails to fetch the meta`. PTAL.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645870074



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
       Thanks for bringing this up @mridulm. This is something I missed as the push-based shuffle code evolved. We don't need this change and I will revert it.
   We do want to fail all the remaining chunkIds as well so that they get retried by `RetryingBlockFetcher`. 
   Even with this code, the remaining chunkIds are going to be retried by `RetryingBlockFetcher` because it retries not only the fetch of the failed block but also the fetch of all the outstanding blocks. In addition, `RetryingBlockFetchListener.initiateRetry` also changes the currentListener, so any invocation of `onBlockFetchFailure()` on an old instance of `RetryingBlockFetchListener` would do nothing because it checks whether the instance is the current one.
   
   In short, this change is unnecessary here.
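   
   To illustrate the stale-listener behavior described above, here is a minimal sketch. It is a simplified, hypothetical model and not the actual `RetryingBlockFetcher` code: the class name `StaleListenerSketch`, the `retriedFor` list, and the block id strings are assumptions made for illustration only.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical, simplified model of the retry logic; not Spark's actual class. */
final class StaleListenerSketch {
  interface FetchListener {
    void onBlockFetchFailure(String blockId, Throwable e);
  }

  private final int maxRetries;
  private int retryCount = 0;
  private FetchListener currentListener;
  // Records the block id that triggered each retry, purely for illustration.
  final List<String> retriedFor = new ArrayList<>();

  StaleListenerSketch(int maxRetries) {
    this.maxRetries = maxRetries;
    this.currentListener = new RetryingListener();
  }

  synchronized FetchListener currentListener() {
    return currentListener;
  }

  private synchronized void initiateRetry(String failedBlockId) {
    retryCount++;
    retriedFor.add(failedBlockId);
    // Swapping the listener makes every older listener instance stale.
    // A real retry would re-request all outstanding blocks at this point.
    currentListener = new RetryingListener();
  }

  private final class RetryingListener implements FetchListener {
    @Override
    public void onBlockFetchFailure(String blockId, Throwable e) {
      synchronized (StaleListenerSketch.this) {
        // Callbacks that arrive on a stale listener are ignored.
        if (this != currentListener) {
          return;
        }
        if (retryCount < maxRetries) {
          initiateRetry(blockId);
        }
      }
    }
  }
}
```

   In this sketch, a second failure reported on the old listener after a retry has started has no effect, which is why failing only a single chunk is not required for the remaining chunks to be retried.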






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646965608



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       Made this change in the PR https://github.com/apache/spark/pull/32811. So resolving the conversation here.






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-818409735


   Adding @Victsm @mridulm @zhouyejoe 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648520952



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +908,43 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
+            blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+        address, _) =>
+          // The original meta request is processed so we decrease numBlocksToFetch by 1. We will
+          // collect new chunks request and the count of this is added to numBlocksToFetch in
+          // collectFetchReqsFromMergedBlocks.
+          numBlocksToFetch -= 1
+          val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, numChunks, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null

Review comment:
       > Hm..is it possible there's only FetchRequest(hasMergedBlocks) at the beginning? In that case, it seems to cause the fetching process to hang.
   
   It will not cause the fetch process to hang if there is just a FetchRequest with merged blocks.
   For example, if there is a FetchRequest for a merged block `ShuffleBlock(0, -1, 0)`:
   - The iterator sends out the request to fetch the metadata for this block in `PushBasedFetchHelper.sendFetchMergedStatusRequest`.
   - The iterator waits for a response in the result queue at `results.take()`.
   - Once it receives a response, which is either `MergedBlocksMetaFetchResult` or `MergedBlocksMetaFailedFetchResult`, it adds more FetchRequests to the fetch queue and sets `result = null`.
   - `fetchUpToMaxBytes()` is always called after processing the response.
   - Since `result = null`, the while loop repeats and waits again for a response in the result queue.
   
   I will also add a UT for this case just to verify this.
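   
   The steps above can be modeled with a small single-threaded sketch. This is a hypothetical simplification, not the real `ShuffleBlockFetcherIterator`: the class `FetchLoopSketch`, the `Result` type, the request strings, and the fixed chunk count of 2 are all assumptions, and the "remote" side is faked by answering requests immediately.

```java
import java.util.ArrayDeque;
import java.util.Queue;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

/** Hypothetical model of the fetch loop; all names here are illustrative only. */
final class FetchLoopSketch {
  /** Either a meta response for a merged block or the data of one chunk. */
  static final class Result {
    final boolean isMeta;
    final int numChunks;   // only meaningful for meta responses
    final String payload;  // only meaningful for chunk data

    Result(boolean isMeta, int numChunks, String payload) {
      this.isMeta = isMeta;
      this.numChunks = numChunks;
      this.payload = payload;
    }
  }

  private final BlockingQueue<Result> results = new LinkedBlockingQueue<>();
  private final Queue<String> fetchRequests = new ArrayDeque<>();

  FetchLoopSketch() {
    // Start with only a merged-block meta request, the case raised in the review.
    fetchRequests.add("meta");
    fetchUpToMaxBytes();
  }

  /** Stand-in for sending requests: the fake remote side answers immediately. */
  private void fetchUpToMaxBytes() {
    while (!fetchRequests.isEmpty()) {
      String req = fetchRequests.remove();
      if (req.equals("meta")) {
        results.add(new Result(true, 2, null));
      } else {
        results.add(new Result(false, 0, "data:" + req));
      }
    }
  }

  String next() {
    Result result = null;
    while (result == null) {
      try {
        result = results.take();
      } catch (InterruptedException e) {
        Thread.currentThread().interrupt();
        throw new IllegalStateException(e);
      }
      if (result.isMeta) {
        // The meta response yields new chunk requests instead of data.
        for (int i = 0; i < result.numChunks; i++) {
          fetchRequests.add("chunk-" + i);
        }
        result = null; // force another iteration of the loop
      }
      // Always called after processing a response, so new requests get sent.
      fetchUpToMaxBytes();
    }
    return result.payload;
  }
}
```

   Because the new requests are sent before the loop blocks on the queue again, the loop in this sketch always has a pending response to wait for.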






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648567584



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -376,48 +418,62 @@ final class ShuffleBlockFetcherIterator(
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         remoteBlockBytes += blockInfos.map(_._2).sum
         collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of merged-local blocks ${mergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"[${context.taskAttemptId()}] Getting $blocksToFetchCurrentIteration " +

Review comment:
       IIUC `info.get.id` = `$index.$attemptNumber`, and the `$index` is not available in taskContext. I will just remove this change from the log since it is not related to push-based shuffle, though it is helpful for debugging in general to have the TID logged here.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646728233



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       If we don't expect it to be null, make it a precondition and remove the check then?
   Not sure if this is an artifact of some earlier iteration of the code.
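   
   As a rough sketch of what that could look like, the snippet below fails fast on a missing buffer instead of silently recording a size of 0. It is only an illustration: `PreconditionSketch` and `sizeToMark` are hypothetical names, a `byte[]` stands in for `ManagedBuffer`, and `java.util.Objects.requireNonNull` is used here rather than whatever precondition helper the project prefers.

```java
import java.util.Objects;

/** Hypothetical illustration of replacing a null check with a precondition. */
final class PreconditionSketch {
  /** Returns the size to record, failing fast if the block is unexpectedly missing. */
  static long sizeToMark(byte[] block) {
    Objects.requireNonNull(block, "merged block chunk must not be null");
    return block.length;
  }
}
```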






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645717463



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);

Review comment:
       I will make this change. This was existing code, so I didn't want to change it much.






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660275905



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,35 +355,56 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+    // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+    // in order to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._3).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback)
+          .contains(address.executorId)) {

Review comment:
       Done






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-853476109


   Gentle ping to help review this PR @tgravescs @attilapiros @Ngone51 @mridulm @Victsm @zhouyejoe 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648845832



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of a fetch from a remote merged block unsuccessfully.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fallback to fetch the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       A lot of methods in `PushBasedFetchHelper` also need access to the iterator instance. The helper needs to work with the iterator to be able to:
   1. add results to the iterator's `results` queue when it receives the meta response.
   2. update the number of blocks to fetch.
   3. fetch fallback blocks when there is a fallback, which in turn removes some pending blocks from `fetchRequests`.
   
   It also needs access to the `shuffleClient`, `blockManager`, and `mapOutputTracker`. Most of the methods in this class will access one or more of these instances.
   
   IMO, it seems better to create an instance of `PushBasedFetchHelper` per iterator instance. Otherwise, all the methods of `PushBasedFetchHelper` will have way more arguments.
   
   I find this class similar to  the existing `BufferReleasingInputStream` in the iterator.






[GitHub] [spark] mridulm commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870967312








[GitHub] [spark] otterc commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870101000


   @mridulm Resolved the conflict.




[GitHub] [spark] otterc edited a comment on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-854796576


   > Took an initial pass, yet to look at `ShuffleBlockFetcherIterator` or test suites.
   > I am wondering, given the volume, whether we want to split between ESS side and client side. Thoughts ?
   
   Thanks Mridul for reviewing!
   My thought on splitting this change is that this PR completely encapsulates the fetch-side changes, so it is easier to understand how the new messages introduced on the client side are being handled on the server side. One piece of feedback we got last year was that we broke things up in a way that made it difficult to understand.
   
   That being said, I am still okay with breaking this change into client/server PRs if that makes the review easier for the reviewers.
   cc. @mridulm @Ngone51 @Victsm @tgravescs @attilapiros 




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645718534



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);

Review comment:
       Same here.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645277961



##########
File path: common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
##########
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to find the meta information for the specified merged block. The meta information
+ * contains the number of chunks in the merged blocks and the maps ids in each chunk.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
+  public final long requestId;
+  public final String appId;
+  public final int shuffleId;
+  public final int reduceId;
+
+  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+    super(null, false);
+    this.requestId = requestId;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.reduceId = reduceId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaRequest;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(appId) + 8;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    Encoders.Strings.encode(buf, appId);
+    buf.writeInt(shuffleId);
+    buf.writeInt(reduceId);
+  }
+
+  public static MergedBlockMetaRequest decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    String appId = Encoders.Strings.decode(buf);
+    int shuffleId = buf.readInt();
+    int reduceId = buf.readInt();
+    return new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, appId, shuffleId, reduceId);
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof MergedBlockMetaRequest) {
+      MergedBlockMetaRequest o = (MergedBlockMetaRequest) other;
+      return requestId == o.requestId && Objects.equal(appId, o.appId)
+        && shuffleId == o.shuffleId && reduceId == o.reduceId;

Review comment:
       nit: move the `appId` check last.
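   
   A minimal sketch of the reordering, using a hypothetical stand-in class (`MetaRequestKeySketch`) rather than the real message type, and `java.util.Objects` in place of the Guava `Objects` used in the PR. The primitive comparisons run first, so the `String` comparison is reached only when they all match:
   
   ```java
   import java.util.Objects;
   
   // Hypothetical stand-in for MergedBlockMetaRequest; only the fields used by equals are kept.
   final class MetaRequestKeySketch {
     final long requestId;
     final String appId;
     final int shuffleId;
     final int reduceId;
   
     MetaRequestKeySketch(long requestId, String appId, int shuffleId, int reduceId) {
       this.requestId = requestId;
       this.appId = appId;
       this.shuffleId = shuffleId;
       this.reduceId = reduceId;
     }
   
     @Override
     public boolean equals(Object other) {
       if (other instanceof MetaRequestKeySketch) {
         MetaRequestKeySketch o = (MetaRequestKeySketch) other;
         // Cheap primitive checks short-circuit before the String comparison.
         return requestId == o.requestId && shuffleId == o.shuffleId
           && reduceId == o.reduceId && Objects.equals(appId, o.appId);
       }
       return false;
     }
   
     @Override
     public int hashCode() {
       return Objects.hash(requestId, appId, shuffleId, reduceId);
     }
   }
   ```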

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {

Review comment:
       nit: I don't have a good suggestion, but any thoughts on a better name for this interface?
   Thoughts, @Ngone51, @attilapiros?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {

Review comment:
       Reviewer note: the `Iterator` contract requires `next` to throw `NoSuchElementException` when `hasNext` is false.
   Unfortunately, the other iterators in `ExternalBlockHandler` do not do this either.
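   
   For illustration, a sketch of a two-level cursor that honors the contract. `ChunkIdIteratorSketch` is a hypothetical class that iterates over plain `int` values instead of `ManagedBuffer`s, and, like the code under review, it assumes every inner array is non-empty:
   
   ```java
   import java.util.Iterator;
   import java.util.NoSuchElementException;
   
   // Hypothetical iterator over int[][]; mirrors the reduceIdx/chunkIdx cursor logic.
   final class ChunkIdIteratorSketch implements Iterator<Integer> {
     private final int[][] chunkIds;
     private int reduceIdx = 0;
     private int chunkIdx = 0;
   
     ChunkIdIteratorSketch(int[][] chunkIds) {
       this.chunkIds = chunkIds;
     }
   
     @Override
     public boolean hasNext() {
       return reduceIdx < chunkIds.length && chunkIdx < chunkIds[reduceIdx].length;
     }
   
     @Override
     public Integer next() {
       // Guard required by the Iterator contract.
       if (!hasNext()) {
         throw new NoSuchElementException();
       }
       int value = chunkIds[reduceIdx][chunkIdx];
       if (chunkIdx < chunkIds[reduceIdx].length - 1) {
         chunkIdx += 1;
       } else {
         chunkIdx = 0;
         reduceIdx += 1;
       }
       return value;
     }
   }
   ```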

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));

Review comment:
       Add a one-line note on what `blockIdParts[3]` can be.

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -124,11 +134,12 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r

Review comment:
       nit: use `\\d+` instead?
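   
   A small Java sketch comparing the two spellings. It relies on `java.util.regex`, which Scala's `.r` compiles to as far as I know, and where `\d` matches only ASCII digits unless `UNICODE_CHARACTER_CLASS` is set. The class and method names are hypothetical:
   
   ```java
   import java.util.regex.Pattern;
   
   final class ShuffleIdPatternSketch {
     // Pattern as written in the diff.
     static final Pattern WITH_CLASS =
       Pattern.compile("shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)");
     // Same pattern using the \d shorthand.
     static final Pattern WITH_SHORTHAND =
       Pattern.compile("shuffle_(\\d+)_(-?\\d+)_(\\d+)");
   
     static boolean matchesBoth(String blockId) {
       return WITH_CLASS.matcher(blockId).matches()
         && WITH_SHORTHAND.matcher(blockId).matches();
     }
   
     static boolean matchesNeither(String blockId) {
       return !WITH_CLASS.matcher(blockId).matches()
         && !WITH_SHORTHAND.matcher(blockId).matches();
     }
   }
   ```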

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -128,24 +134,23 @@ protected void handleMessage(
       BlockTransferMessage msgObj,
       TransportClient client,
       RpcResponseCallback callback) {
-    if (msgObj instanceof FetchShuffleBlocks || msgObj instanceof OpenBlocks) {
+    if (msgObj instanceof AbstractFetchShuffleBlocks || msgObj instanceof OpenBlocks) {
       final Timer.Context responseDelayContext = metrics.openBlockRequestLatencyMillis.time();
       try {
         int numBlockIds;
         long streamId;
-        if (msgObj instanceof FetchShuffleBlocks) {
-          FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj;
+        if (msgObj instanceof AbstractFetchShuffleBlocks) {
+          AbstractFetchShuffleBlocks msg = (AbstractFetchShuffleBlocks) msgObj;
           checkAuth(client, msg.appId);
-          numBlockIds = 0;
-          if (msg.batchFetchEnabled) {
-            numBlockIds = msg.mapIds.length;
+          numBlockIds = ((AbstractFetchShuffleBlocks) msgObj).getNumBlocks();

Review comment:
       `getNumBlocks` makes this code cleaner.

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
##########
@@ -199,14 +200,31 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.errorString);
       } else {
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof MergedBlockMetaSuccess) {
+      MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
+      MergedBlockMetaResponseCallback listener =
+        (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
+      if (listener == null) {
+        logger.warn(
+          "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
+            + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
+        resp.body().release();
+      } else {
+        outstandingRpcs.remove(resp.requestId);
+        try {
+          listener.onSuccess(resp.getNumChunks(), resp.body());
+        } finally {
+          resp.body().release();

Review comment:
       nit: move `resp.body().release()` to try/finally for this entire else block.
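   
   Roughly what I have in mind, sketched with hypothetical stand-ins (`Body` for the reference-counted response body, an `IntConsumer` for the callback) rather than the real `TransportResponseHandler` types. The body is released even if removing the entry or invoking the listener throws:
   
   ```java
   import java.util.Map;
   import java.util.concurrent.ConcurrentHashMap;
   import java.util.function.IntConsumer;
   
   final class ReleaseInFinallySketch {
     // Hypothetical stand-in for a reference-counted buffer.
     static final class Body {
       int refCount = 1;
   
       void release() {
         refCount -= 1;
       }
     }
   
     static final Map<Long, IntConsumer> outstanding = new ConcurrentHashMap<>();
   
     static void handle(long requestId, int numChunks, Body body) {
       IntConsumer listener = outstanding.get(requestId);
       if (listener == null) {
         body.release();
       } else {
         // The try/finally covers the whole else block, not just the callback.
         try {
           outstanding.remove(requestId);
           listener.accept(numChunks);
         } finally {
           body.release();
         }
       }
     }
   }
   ```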

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -189,9 +194,14 @@ protected void handleMessage(
     } else if (msgObj instanceof GetLocalDirsForExecutors) {
       GetLocalDirsForExecutors msg = (GetLocalDirsForExecutors) msgObj;
       checkAuth(client, msg.appId);
-      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, msg.execIds);
+      String[] execIdsForBlockResolver = Arrays.stream(msg.execIds)
+        .filter(execId -> !SHUFFLE_MERGER_IDENTIFIER.equals(execId)).toArray(String[]::new);
+      Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId,
+        execIdsForBlockResolver);
+      if (Arrays.asList(msg.execIds).contains(SHUFFLE_MERGER_IDENTIFIER)) {
+        localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
+      }

Review comment:
       ```suggestion
         Set<String> execIdsForBlockResolver = Sets.newHashSet(msg.execIds);
         boolean fetchMergedBlockDirs = execIdsForBlockResolver.remove(SHUFFLE_MERGER_IDENTIFIER);
         Map<String, String[]> localDirs = blockManager.getLocalDirs(msg.appId, execIdsForBlockResolver);
         if (fetchMergedBlockDirs) {
           localDirs.put(SHUFFLE_MERGER_IDENTIFIER, mergeManager.getMergedBlockDirs(msg.appId));
         }
   ```
   
   With a corresponding change in `blockManager.getLocalDirs` to take a set of executor ids instead of an array.
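   
   The point of the suggestion is that `Set.remove` both filters out the merger identifier and reports whether it was present. A standalone sketch, using `java.util.HashSet` instead of Guava's `Sets.newHashSet`; the class name and the identifier's value are assumptions made for illustration:
   
   ```java
   import java.util.Arrays;
   import java.util.HashSet;
   import java.util.Set;
   
   final class MergerIdFilterSketch {
     // Assumed value, for illustration only.
     static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger";
   
     /** Copies the array into a mutable set (duplicates and ordering are dropped). */
     static Set<String> toMutableSet(String[] execIds) {
       return new HashSet<>(Arrays.asList(execIds));
     }
   
     /** Removes the merger identifier in place and reports whether it was present. */
     static boolean removeMergerId(Set<String> execIds) {
       return execIds.remove(SHUFFLE_MERGER_IDENTIFIER);
     }
   }
   ```
   
   One trade-off to keep in mind: the set does not preserve the order of `msg.execIds`, which only matters if a caller depends on it.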

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);

Review comment:
       Just to clarify, we are not modifying the old fetch protocol at all.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }

Review comment:
       nit:
   
   ```suggestion
         BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, id -> new BlocksInfo());
   ```
   
   and remove the `get` below.
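   
   A standalone sketch of the same grouping idiom, with a hypothetical `GroupByPrimaryIdSketch` helper that groups `(primaryId, secondaryId)` pairs while keeping first-seen order, as the `LinkedHashMap` in the PR does:
   
   ```java
   import java.util.ArrayList;
   import java.util.LinkedHashMap;
   import java.util.List;
   import java.util.Map;
   
   final class GroupByPrimaryIdSketch {
     /** Groups pairs of {primaryId, secondaryId} by primaryId, preserving insertion order. */
     static Map<Integer, List<Integer>> group(int[][] pairs) {
       Map<Integer, List<Integer>> byPrimaryId = new LinkedHashMap<>();
       for (int[] pair : pairs) {
         // One lookup replaces the containsKey / put / get sequence.
         byPrimaryId.computeIfAbsent(pair[0], id -> new ArrayList<>()).add(pair[1]);
       }
       return byPrimaryId;
     }
   }
   ```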

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -333,14 +382,18 @@ public ShuffleMetrics() {
       final int[] mapIdAndReduceIds = new int[2 * blockIds.length];
       for (int i = 0; i < blockIds.length; i++) {
         String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
+        if (blockIdParts.length != 4
+          || (!requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX))
+          || (requestForMergedBlockChunks && !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
           throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
         }
         if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
           throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
             ", got:" + blockIds[i]);
         }
+        // For regular blocks this is mapId. For chunks this is reduceId.
         mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
+        // For regular blocks this is reduceId. For chunks this is chunkId.
         mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);

Review comment:
       Do we want to rename this variable (here and in the constructor) and this method, given that the same slots now hold either map/reduce IDs or reduce/chunk IDs?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);

Review comment:
       Move this assertion into the constructor itself.
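   
   A sketch of validating the shape once at construction time instead of on every `hasNext` call. It uses a hypothetical holder class and swaps the `assert` for an explicit `IllegalArgumentException` so the check also runs when assertions are disabled; keeping `assert` would be structurally the same:
   
   ```java
   // Hypothetical holder for the reduceIds/chunkIds pair carried by the request.
   final class ChunkRequestShapeSketch {
     final int[] reduceIds;
     final int[][] chunkIds;
   
     ChunkRequestShapeSketch(int[] reduceIds, int[][] chunkIds) {
       // Checked once here; hasNext() can then rely on the invariant.
       if (reduceIds.length == 0 || reduceIds.length != chunkIds.length) {
         throw new IllegalArgumentException(
           "Expected non-empty reduceIds with one chunkIds entry each, got "
             + reduceIds.length + " reduceIds and " + chunkIds.length + " chunkIds entries");
       }
       this.reduceIds = reduceIds;
       this.chunkIds = chunkIds;
     }
   }
   ```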

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {

Review comment:
       Here we assume that the IDs are either all chunks or all blocks.
   That is not what `areShuffleBlocksOrChunks` validates: a mix of both can pass.
   
   Do we want to make `areShuffleBlocksOrChunks` stricter?
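
One way the stricter check raised above could look, as a minimal sketch rather than the PR's code: classify each ID as a merged chunk or an unmerged block and accept the array only when every ID falls in the same class. The class name `HomogeneousShuffleIdCheck`, the method names, and the two prefix values are assumptions made for illustration. The chunk prefix is tested first so the sketch also holds if the block prefix happens to be a prefix of the chunk prefix.

```java
import java.util.Arrays;

/** Illustrative only: a homogeneity check for shuffle block IDs vs. merged chunk IDs. */
final class HomogeneousShuffleIdCheck {
  // Assumed values; the PR defines its own SHUFFLE_BLOCK_PREFIX / SHUFFLE_CHUNK_PREFIX.
  static final String SHUFFLE_BLOCK_PREFIX = "shuffle_";
  static final String SHUFFLE_CHUNK_PREFIX = "shuffleChunk_";

  static boolean isChunk(String blockId) {
    return blockId.startsWith(SHUFFLE_CHUNK_PREFIX);
  }

  static boolean isUnmergedBlock(String blockId) {
    // Exclude chunk IDs explicitly in case the block prefix is a prefix of the chunk prefix.
    return blockId.startsWith(SHUFFLE_BLOCK_PREFIX) && !isChunk(blockId);
  }

  /** True only when the array is non-empty and is all chunks or all unmerged blocks. */
  static boolean areAllBlocksOrAllChunks(String[] blockIds) {
    if (blockIds.length == 0) {
      return false;
    }
    return Arrays.stream(blockIds).allMatch(HomogeneousShuffleIdCheck::isChunk)
        || Arrays.stream(blockIds).allMatch(HomogeneousShuffleIdCheck::isUnmergedBlock);
  }

  public static void main(String[] args) {
    String[] mixed = {"shuffle_0_1_2", "shuffleChunk_0_2_0"};
    System.out.println(areAllBlocksOrAllChunks(mixed)); // prints false
  }
}
```

With a check of this shape, looking only at `blockIds[0]` to pick the message type would be consistent with what the validation guarantees.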

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -413,6 +466,47 @@ public ManagedBuffer next() {
     }
   }
 
+  private class ShuffleChunkManagedBufferIterator implements Iterator<ManagedBuffer> {
+
+    private int reduceIdx = 0;
+    private int chunkIdx = 0;
+
+    private final String appId;
+    private final int shuffleId;
+    private final int[] reduceIds;
+    private final int[][] chunkIds;
+
+    ShuffleChunkManagedBufferIterator(FetchShuffleBlockChunks msg) {
+      appId = msg.appId;
+      shuffleId = msg.shuffleId;
+      reduceIds = msg.reduceIds;
+      chunkIds = msg.chunkIds;
+    }
+
+    @Override
+    public boolean hasNext() {
+      // reduceIds.length must equal to chunkIds.length, and the passed in FetchShuffleBlockChunks
+      // must have non-empty reduceIds and chunkIds, see the checking logic in
+      // OneForOneBlockFetcher.
+      assert(reduceIds.length != 0 && reduceIds.length == chunkIds.length);
+      return reduceIdx < reduceIds.length && chunkIdx < chunkIds[reduceIdx].length;
+    }
+
+    @Override
+    public ManagedBuffer next() {
+      ManagedBuffer block = mergeManager.getMergedBlockData(
+        appId, shuffleId, reduceIds[reduceIdx], chunkIds[reduceIdx][chunkIdx]);
+      if (chunkIdx < chunkIds[reduceIdx].length - 1) {
+        chunkIdx += 1;
+      } else {
+        chunkIdx = 0;
+        reduceIdx += 1;
+      }
+      metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);

Review comment:
       When would `block` be `null` ?

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       nit: `Longs.toArray` is a bit expensive - same for `Ints.toArray` below.
   If we can avoid it while keeping the code clean/concise, that would be preferable (there are a couple of other locations in this PR that use these APIs).
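
A sketch of one way to build the primitive arrays without the Guava helpers, assuming the keys are held as `Number` in an insertion-ordered map as in the diff above: fill a preallocated array in a single pass. `PrimitiveKeyArrays` and its method names are hypothetical, and this sketch does not demonstrate that the approach is measurably cheaper than `Longs.toArray` / `Ints.toArray`.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustrative only: single-pass conversion of boxed IDs into primitive arrays. */
final class PrimitiveKeyArrays {

  /** Copies the map's keys, in iteration order, into a long[]. */
  static long[] keysAsLongs(Map<? extends Number, ?> map) {
    long[] out = new long[map.size()];
    int i = 0;
    for (Number key : map.keySet()) {
      out[i++] = key.longValue();
    }
    return out;
  }

  /** Copies a list of boxed ints into an int[] without an intermediate Object[]. */
  static int[] toIntArray(List<Integer> values) {
    int[] out = new int[values.size()];
    for (int i = 0; i < out.length; i++) {
      out[i] = values.get(i);
    }
    return out;
  }

  public static void main(String[] args) {
    Map<Number, String> byMapId = new LinkedHashMap<>();
    byMapId.put(5L, "first");
    byMapId.put(3L, "second");
    System.out.println(keysAsLongs(byMapId).length); // prints 2
  }
}
```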

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {

Review comment:
       super nit: As coded, checking for `SHUFFLE_CHUNK_PREFIX` here is redundant - though I am fine with it for clarity.
   Btw, we are avoiding a '_' suffix check here.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
       We can have `failRemainingBlocks` delegate to `failSingleBlockChunk` now ?
   ```
     private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
       Arrays.stream(failedBlockIds).forEach(blockId -> failSingleBlockChunk(blockId, e));
     }
   ```

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;

Review comment:
       ```suggestion
     return Arrays.stream(blockIds).noneMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX) && !blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
   ```
   
   
   Review note: startsWith `SHUFFLE_BLOCK_PREFIX` is a superset of startsWith `SHUFFLE_CHUNK_PREFIX` - though I am fine with keeping them separate in the interest of clarity.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);

Review comment:
       Iterate over `primaryIdToBlocksInfo.entrySet` instead ?
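
A minimal sketch of the `entrySet` iteration suggested here, under the assumption that the loop only needs each entry's value plus a running index. `EntrySetIteration`, its `collect` method, and the simplified `BlocksInfo` stand-in are hypothetical; only the field names `ids` and `blockIds` are taken from the diff.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Illustrative only: one pass over entrySet() instead of keySet() plus get(). */
final class EntrySetIteration {

  /** Simplified stand-in for the PR's BlocksInfo holder. */
  static final class BlocksInfo {
    final List<Integer> ids = new ArrayList<>();
    final List<String> blockIds = new ArrayList<>();
  }

  /**
   * Returns the per-primary-id secondary ids and writes the block ids into
   * orderedBlockIds in the same order the entries are visited.
   */
  static int[][] collect(Map<Number, BlocksInfo> primaryIdToBlocksInfo, String[] orderedBlockIds) {
    int[][] secondaryIds = new int[primaryIdToBlocksInfo.size()][];
    int secIndex = 0;
    int blockIdIndex = 0;
    for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) {
      BlocksInfo info = entry.getValue(); // no second hash lookup needed
      secondaryIds[secIndex++] = info.ids.stream().mapToInt(Integer::intValue).toArray();
      for (String blockId : info.blockIds) {
        orderedBlockIds[blockIdIndex++] = blockId;
      }
    }
    return secondaryIds;
  }

  public static void main(String[] args) {
    Map<Number, BlocksInfo> map = new LinkedHashMap<>();
    BlocksInfo info = new BlocksInfo();
    info.ids.add(2);
    info.blockIds.add("shuffle_0_7_2");
    map.put(7L, info);
    String[] ordered = new String[1];
    System.out.println(collect(map, ordered).length + " " + ordered[0]); // prints 1 shuffle_0_7_2
  }
}
```

Because a `LinkedHashMap` iterates entries in insertion order, the block ID order produced this way matches the order produced by iterating the key set.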

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
       What is the expected behavior if closing the channel throws an exception? (For example, when the failure itself was caused by `onData` throwing.)

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));

Review comment:
       Update the comment above, or add a one-line note on what `blockIdParts[4]` can be.
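
To illustrate what such a note might document, here is a sketch based only on what the diff shows: a batch block ID splits into five parts, and the fourth and fifth parts are recorded as the range `[startReduceId, endReduceId)`. A four-part ID contributes a single secondary id. `BlockIdParts` and `secondaryIds` are hypothetical names, and splitting on `_` is an assumption about what the PR's `splitBlockId` helper does.

```java
import java.util.Arrays;

/** Illustrative only: which parts of a split block id end up as secondary ids. */
final class BlockIdParts {

  static int[] secondaryIds(String blockId) {
    String[] parts = blockId.split("_");
    if (parts.length == 4) {
      // <prefix>_<shuffleId>_<primaryId>_<secondaryId>
      return new int[] {Integer.parseInt(parts[3])};
    }
    if (parts.length == 5) {
      // Batch fetch: parts[3] is startReduceId and parts[4] is endReduceId (exclusive).
      return new int[] {Integer.parseInt(parts[3]), Integer.parseInt(parts[4])};
    }
    throw new IllegalArgumentException("Unexpected block id format: " + blockId);
  }

  public static void main(String[] args) {
    System.out.println(Arrays.toString(secondaryIds("shuffle_0_7_2_5"))); // prints [2, 5]
  }
}
```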

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal to chunkIds.size().

Review comment:
       How strong is this assumption ? Do we see a future evolution where this can break ? Or is it tied to the protocol in nontrivial ways ?
   As an example - `encode` and `decode` do not assume this currently (we could have avoided writing `chunkIdsLen` if they did)
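
If the equal-length assumption is meant to be a hard invariant, one option is to enforce it when the message is constructed rather than only asserting it at read time. The following is a sketch under that assumption; `ChunkRequestShape` is a hypothetical stand-in and not the PR's `FetchShuffleBlockChunks` class.

```java
/** Illustrative only: enforce reduceIds.length == chunkIds.length at construction. */
final class ChunkRequestShape {
  final int[] reduceIds;
  final int[][] chunkIds;

  ChunkRequestShape(int[] reduceIds, int[][] chunkIds) {
    if (reduceIds.length != chunkIds.length) {
      throw new IllegalArgumentException(
          "reduceIds.length=" + reduceIds.length
              + " must equal chunkIds.length=" + chunkIds.length);
    }
    this.reduceIds = reduceIds;
    this.chunkIds = chunkIds;
  }

  /** Total number of chunks requested across all reduce ids. */
  int numChunks() {
    int total = 0;
    for (int[] ids : chunkIds) {
      total += ids.length;
    }
    return total;
  }

  public static void main(String[] args) {
    ChunkRequestShape shape = new ChunkRequestShape(new int[] {0, 1}, new int[][] {{0, 1}, {0}});
    System.out.println(shape.numChunks()); // prints 3
  }
}
```

A decoder that trusted this invariant could skip writing a separate length for the second array; keeping both lengths on the wire, as the comment notes the current `encode`/`decode` do, leaves room for the shapes to diverge later.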

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);

Review comment:
       ```suggestion
       for (String blockId : blocksInfoByPrimaryId.blockIds) {
           this.blockIds[blockIdIndex++] = blockId;
       }
   ```
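A minimal stand-alone sketch of the loop shape suggested above. `BlockIdFlattener` and `flatten` are hypothetical names used only for illustration, and the grouping key (the third underscore-separated field) is a simplification of what the diff does with `primaryIdToBlocksInfo`. A `LinkedHashMap` keeps the groups in first-seen order, so the flattened order is deterministic.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical stand-alone illustration; not part of OneForOneBlockFetcher.
final class BlockIdFlattener {

  // Groups block IDs by their third field and returns them flattened in group order.
  static String[] flatten(String[] blockIds) {
    // LinkedHashMap preserves the order in which primary IDs are first seen.
    Map<String, List<String>> byPrimaryId = new LinkedHashMap<>();
    for (String blockId : blockIds) {
      String primaryId = blockId.split("_")[2];
      byPrimaryId.computeIfAbsent(primaryId, k -> new ArrayList<>()).add(blockId);
    }
    String[] ordered = new String[blockIds.length];
    int index = 0;
    for (List<String> group : byPrimaryId.values()) {
      // Enhanced for loop, as in the suggestion: no index variable, no get(j).
      for (String blockId : group) {
        ordered[index++] = blockId;
      }
    }
    return ordered;
  }

  public static void main(String[] args) {
    String[] in = {"shuffle_0_1_0", "shuffle_0_2_0", "shuffle_0_1_1"};
    System.out.println(String.join(",", flatten(in)));
  }
}
```

The enhanced loop also avoids repeated `get(j)` calls, which matters if the list type is ever changed to one without constant-time indexed access.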




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org


[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648537705



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -347,20 +361,48 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)

Review comment:
       For merged blocks that are remote, this check is performed in `collectFetchRequests`. So if we do it before the condition `if (address.host == blockManager.blockManagerId.host)`, it will be done twice for remote merged blocks.
   I think this is also the reason `checkBlockSizes()` is currently called explicitly for each block type. For remote blocks (the last else block), the size is validated in `collectFetchRequests`.
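To make the reasoning concrete, here is a rough Java sketch of the branch structure. It is not the Scala code from the PR: `SizeCheckSketch`, `partition`, and the `validations` counter are hypothetical, and the body of `checkBlockSizes` is an assumption (it only rejects negative sizes). The point is that each path validates sizes exactly once.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical simplification of the branches discussed above; not Spark code.
final class SizeCheckSketch {
  // Counts validation passes so that double-checking would be visible.
  static int validations = 0;

  // Assumed behaviour for this sketch: reject any negative block size.
  static void checkBlockSizes(long[] sizes) {
    validations++;
    for (long size : sizes) {
      if (size < 0) {
        throw new IllegalArgumentException("Negative block size " + size);
      }
    }
  }

  // Stand-in for collectFetchRequests: the remote path validates sizes itself.
  static List<Long> collectFetchRequests(long[] sizes) {
    checkBlockSizes(sizes);
    List<Long> requests = new ArrayList<>();
    for (long size : sizes) {
      requests.add(size);
    }
    return requests;
  }

  // Returns the number of remote requests created for this address.
  static int partition(boolean sameHost, long[] sizes) {
    if (sameHost) {
      // Local merged blocks never reach collectFetchRequests, so check here.
      checkBlockSizes(sizes);
      return 0;
    }
    // No check here: collectFetchRequests performs it for remote blocks.
    return collectFetchRequests(sizes).size();
  }

  public static void main(String[] args) {
    partition(true, new long[] {10L, 20L});
    partition(false, new long[] {30L});
    System.out.println("validations = " + validations);
  }
}
```

Hoisting the check above the host comparison would increment the counter twice on the remote path, which is the duplication described in the comment.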






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656706415



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,

Review comment:
       They can't be different. We have the assertion `assert (bitmaps.size() == numChunks)` in `MergedBlockMeta.readChunkBitmaps()`, and execution reaches this point only when that assertion holds. When it fails, `PushMergedRemoteMetaFailedFetchResult` is posted, which triggers the fallback.
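The flow described above could be sketched roughly as follows. Everything here is a hypothetical stand-in: `MetaCheckSketch`, `MetaResult`, and the two `Kind` values are invented for illustration, `java.util.BitSet` replaces `RoaringBitmap` to keep the example dependency-free, and an explicit exception replaces the `assert` because Java assertions are disabled by default.

```java
import java.util.ArrayList;
import java.util.BitSet;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-ins; BitSet replaces RoaringBitmap to avoid a dependency.
final class MetaCheckSketch {
  enum Kind { META_FETCHED, META_FAILED_FALLBACK }

  static final class MetaResult {
    final Kind kind;
    final List<BitSet> bitmaps;

    MetaResult(Kind kind, List<BitSet> bitmaps) {
      this.kind = kind;
      this.bitmaps = bitmaps;
    }
  }

  // Mirrors the invariant from the comment: one bitmap per chunk, or fail.
  static List<BitSet> readChunkBitmaps(List<BitSet> decoded, int numChunks) {
    if (decoded.size() != numChunks) {
      throw new IllegalStateException(
        "Expected " + numChunks + " bitmaps, got " + decoded.size());
    }
    return decoded;
  }

  // Success carries exactly numChunks bitmaps; any mismatch becomes a fallback.
  static MetaResult handleMeta(List<BitSet> decoded, int numChunks) {
    try {
      return new MetaResult(Kind.META_FETCHED, readChunkBitmaps(decoded, numChunks));
    } catch (IllegalStateException e) {
      return new MetaResult(Kind.META_FAILED_FALLBACK, Collections.<BitSet>emptyList());
    }
  }

  public static void main(String[] args) {
    List<BitSet> two = new ArrayList<>();
    two.add(new BitSet());
    two.add(new BitSet());
    System.out.println(handleMeta(two, 2).kind);
    System.out.println(handleMeta(two, 3).kind);
  }
}
```

Under these assumptions, a consumer of a successful result can rely on `bitmaps.size()` alone, so a separate `numChunks` argument carries no extra information.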






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r657613787



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,325 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple
+ * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that
+ * chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block. false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a local push-merged block. false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param numChunks number of chunks in the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,

Review comment:
       Removed passing numChunks in the iterator.






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660751432



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -386,40 +415,53 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration

Review comment:
       I have made this change, and also added a var that counts the number of hostLocalBlocks, which is needed for the assertions. PTAL.






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640322571



##########
File path: core/src/main/scala/org/apache/spark/MapOutputTracker.scala
##########
@@ -742,10 +742,10 @@ private[spark] class MapOutputTrackerMaster(
    *                  result.
    */
   def unregisterMergeResult(
-    shuffleId: Int,
-    reduceId: Int,
-    bmAddress: BlockManagerId,
-    mapId: Option[Int] = None) {
+      shuffleId: Int,
+      reduceId: Int,
+      bmAddress: BlockManagerId,
+      mapId: Option[Int] = None): Unit = {

Review comment:
       Note to reviewers: Same here. This was added as part of push-based shuffle. There was a warning, so I am fixing it; the indentation was off as well.






[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646726664



##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} in a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {

Review comment:
       I am sort of OK with `BaseResponseCallback`; I was just not sure whether there was a better name for it.
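For context on what the name has to convey, here is a rough sketch of the hierarchy the Javadoc describes. All `Sketch*` names are hypothetical, and the method names and parameter types are assumptions chosen to keep the example self-contained (plain byte arrays stand in for whatever buffer types the real callbacks use).

```java
import java.util.ArrayList;
import java.util.List;

// All names below are hypothetical simplifications, not the Spark interfaces.
interface SketchBaseCallback {
  void onFailure(Throwable cause);
}

interface SketchRpcCallback extends SketchBaseCallback {
  void onSuccess(byte[] response);
}

interface SketchMetaCallback extends SketchBaseCallback {
  void onSuccess(int numChunks, byte[] meta);
}

final class SketchResponseHandler {
  // One failure path serves both request kinds because it only needs the base type.
  static void failAll(List<SketchBaseCallback> outstanding, Throwable cause) {
    for (SketchBaseCallback callback : outstanding) {
      callback.onFailure(cause);
    }
  }

  public static void main(String[] args) {
    List<SketchBaseCallback> outstanding = new ArrayList<>();
    outstanding.add(new SketchRpcCallback() {
      @Override public void onSuccess(byte[] response) { }
      @Override public void onFailure(Throwable cause) {
        System.out.println("rpc failed: " + cause.getMessage());
      }
    });
    outstanding.add(new SketchMetaCallback() {
      @Override public void onSuccess(int numChunks, byte[] meta) { }
      @Override public void onFailure(Throwable cause) {
        System.out.println("meta failed: " + cause.getMessage());
      }
    });
    failAll(outstanding, new RuntimeException("connection closed"));
  }
}
```

In this sketch the base type exists only so that one collection and one failure path can cover both callback kinds, which is what the name would need to suggest.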






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646807905



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       This is invoked for each fetch request from the client, and a fetch request covers multiple blocks. Since it runs once per fetch request, it is still frequent, so I will create utils for it.
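The comment does not say what shape the utils will take. One hypothetical shape, sketched here without Guava, is a pair of static helpers that convert the `Number` key set into the primitive arrays the two message types need. `PrimaryIdArrays`, `toLongArray`, and `toIntArray` are invented names for illustration only.

```java
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Set;

// Hypothetical helper class; the names are illustrative, not from the PR.
final class PrimaryIdArrays {

  // Converts IDs to long[] in iteration order (mapIds in the unmerged case).
  static long[] toLongArray(Collection<? extends Number> ids) {
    long[] out = new long[ids.size()];
    int i = 0;
    for (Number id : ids) {
      out[i++] = id.longValue();
    }
    return out;
  }

  // Converts IDs to int[] in iteration order (reduceIds in the merged case).
  // Math.toIntExact throws ArithmeticException instead of silently truncating.
  static int[] toIntArray(Collection<? extends Number> ids) {
    int[] out = new int[ids.size()];
    int i = 0;
    for (Number id : ids) {
      out[i++] = Math.toIntExact(id.longValue());
    }
    return out;
  }

  public static void main(String[] args) {
    Set<Number> ids = new LinkedHashSet<>();
    ids.add(5L);
    ids.add(3L);
    System.out.println(toLongArray(ids).length + " ids converted");
  }
}
```

Each helper makes a single pass over the set with one array allocation, so the per-request cost stays proportional to the number of distinct primary IDs.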






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645901796



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -246,6 +304,14 @@ private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
     }
   }
 
+  private void failSingleBlockChunk(String shuffleBlockChunkId, Throwable e) {
+    try {
+      listener.onBlockFetchFailure(shuffleBlockChunkId, e);
+    } catch (Exception e2) {
+      logger.error("Error from blockFetchFailure callback", e2);
+    }
+  }

Review comment:
       Reverted this change.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -276,9 +342,13 @@ public void onComplete(String streamId) throws IOException {
     @Override
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();

Review comment:
       Reverted this change.






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r656795572



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -661,18 +745,21 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
-              shuffleMetrics.incLocalBlocksFetched(1)
-              shuffleMetrics.incLocalBytesRead(buf.size)
-            } else {
-              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
-              shuffleMetrics.incRemoteBytesRead(buf.size)
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
-                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
-              }
-              shuffleMetrics.incRemoteBlocksFetched(1)
-              bytesInFlight -= size
-            }
+           if (hostLocalBlocks.contains(blockId -> mapIndex) ||
+             pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
+             // It is a host local block or a local shuffle chunk
+             shuffleMetrics.incLocalBlocksFetched(1)
+             shuffleMetrics.incLocalBytesRead(buf.size)
+           } else {
+             // Could be a remote shuffle chunk or remote block
+             numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+             shuffleMetrics.incRemoteBytesRead(buf.size)
+             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+               shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+             }
+             shuffleMetrics.incRemoteBlocksFetched(1)
+             bytesInFlight -= size
+           }

Review comment:
       This is fixed now.






[GitHub] [spark] Ngone51 commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870235923


   Sorry for the delay. I'll do a review today. BTW, are there any other necessary Magnet PRs that have to be merged for the 3.2 release?




[GitHub] [spark] mridulm commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
mridulm commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r654759215



##########
File path: core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
##########
@@ -0,0 +1,336 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch merged block meta and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it
+   * can be modified by both the task thread and the netty thread.
+   */
+  private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote merged block.
+   */
+  def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = {
+    assert(isMergedShuffleBlockAddress(address))
+    address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a merged local block, false otherwise.
+   */
+  def isMergedLocal(address: BlockManagerId): Boolean = {
+    isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
+    chunksMetaMap.get(blockId).getCardinality
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle block chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps   per chunk bitmap, where each bitmap contains all the mapIds that are merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / numChunks
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- 0 until numChunks) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " +
+              s"from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding merged local blocks.
+   * @param mergedLocalBlocks set of identified merged local blocks.
+   */
+  def fetchAllMergedLocalBlocks(
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local
+   * blocks.
+   */
+  private def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronous fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception with getting the local dirs for local merged blocks,
+          // we fallback to fetch the original unmerged blocks. We do not report block fetch
+          // failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult(
+              blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
+  /**
+   * Fetch a single local merged block. This can be executed by the task thread as well as
+   * the netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   * @return Boolean represents successful or failed fetch
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+        .readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      iterator.incrementNumBlocksToFetch(bufs.size - 1)

Review comment:
       I was confused as well until I read the comment explaining why it is `size - 1` (and I had to check that this was indeed valid while reviewing) :-)
   Explicitly decrementing by one and then incrementing by `size` will make the intent clear, and the extra cost is negligible.
   I like the proposal, @Ngone51.
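
   For readers following the thread, here is a minimal sketch of the two forms being compared. It assumes a plain `int` in place of the iterator's `numBlocksToFetch`; the class and method names are hypothetical and not part of the PR.

```java
public class BlockCountSketch {
    // Compact form: a single adjustment of (numChunks - 1).
    // The "- 1" accounts for the merged block that was already counted once.
    static int compact(int numBlocksToFetch, int numChunks) {
        return numBlocksToFetch + (numChunks - 1);
    }

    // Explicit form: same arithmetic, but each step carries its own reason.
    static int explicit(int numBlocksToFetch, int numChunks) {
        numBlocksToFetch -= 1;         // the merged block itself is no longer pending
        numBlocksToFetch += numChunks; // each of its chunks is now a pending result
        return numBlocksToFetch;
    }

    public static void main(String[] args) {
        // Both forms should always agree; print them side by side.
        System.out.println(compact(10, 4));
        System.out.println(explicit(10, 4));
    }
}
```

   The two methods compute the same value; the explicit form only trades one extra statement for a comment-free explanation of where the `- 1` comes from.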






[GitHub] [spark] SparkQA commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
SparkQA commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870826812


   **[Test build #140388 has started](https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/140388/testReport)** for PR 32140 at commit [`ad89a02`](https://github.com/apache/spark/commit/ad89a0208a5e3f880fca502c297362388a104dd7).




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648594657



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch from a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fall back to fetching the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       I will work on this. 






[GitHub] [spark] Ngone51 edited a comment on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
Ngone51 edited a comment on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870235923








[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r640312469



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -108,13 +115,6 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val startTimeNs = System.nanoTime()
 
-  /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
-
-  /** Host local blockIds to fetch by executors, excluding zero-sized blocks. */
-  private[this] val hostLocalBlocksByExecutor =
-    LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
-

Review comment:
       Note to reviewers: Both of these are now created locally in `initialize` and `fetchFallbackBlocks`. With push-based shuffle, `partitionBlocksByFetchMode` can be called multiple times: when fetching a merged shuffle block/chunk fails, the fallback finds the original blocks that made up the failed merged block/chunk, and we then need to create new FetchRequests from them.
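
   A rough sketch of why caller-owned collections help when the partitioning step runs more than once. Everything here is a hypothetical stand-in (string block ids, a boolean "is local" flag, the `partition` method); it is not the signature of `partitionBlocksByFetchMode`.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class PartitionSketch {
    /**
     * Hypothetical stand-in for the partitioning step. The map value says
     * whether a block is local. Local ids go into the caller-owned set;
     * remote ids are returned so the caller can build fetch requests.
     */
    static List<String> partition(Map<String, Boolean> blocks, Set<String> localOut) {
        List<String> remote = new ArrayList<>();
        for (Map.Entry<String, Boolean> e : blocks.entrySet()) {
            if (e.getValue()) {
                localOut.add(e.getKey());
            } else {
                remote.add(e.getKey());
            }
        }
        return remote;
    }

    public static void main(String[] args) {
        // Initial pass: a single remote push-merged block.
        Map<String, Boolean> initial = new LinkedHashMap<>();
        initial.put("merged_block", false);
        Set<String> localInitial = new LinkedHashSet<>();
        System.out.println(partition(initial, localInitial) + " " + localInitial);

        // Fallback pass: the merged fetch failed, so the original blocks are
        // partitioned again with a fresh set that holds nothing from pass one.
        Map<String, Boolean> fallback = new LinkedHashMap<>();
        fallback.put("original_block_a", true);
        fallback.put("original_block_b", false);
        Set<String> localFallback = new LinkedHashSet<>();
        System.out.println(partition(fallback, localFallback) + " " + localFallback);
    }
}
```

   Because each call receives its own set, a fallback pass cannot accidentally re-process local blocks recorded by an earlier pass, which is the hazard a shared instance field would introduce.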






[GitHub] [spark] otterc commented on a change in pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r660771829



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -767,6 +878,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the local push-merged meta. In this case, the blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the local push-merged directories from the ESS. In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+          case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) =>
+            // Fetch local push-merged shuffle block data as multiple shuffle chunks
+            val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+            try {
+              val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+                localDirs)
+              // Since the request for local block meta completed successfully, numBlocksToFetch
+              // is decremented.
+              numBlocksToFetch -= 1
+              // Update total number of blocks to fetch, reflecting the multiple local shuffle
+              // chunks.
+              numBlocksToFetch += bufs.size
+              bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+                  isNetworkReqDone = false))
+              }
+            } catch {
+              case e: Exception =>
+                // If we see an exception with reading local push-merged data, we fallback to

Review comment:
       Right, it doesn't read the data file; it just creates ManagedBuffers. I will change the comment.
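
   To illustrate the distinction, here is a small sketch of a buffer that only records where its bytes live and defers all file I/O until it is opened. `SegmentRef` is a hypothetical class written for this example; it is an analogy for a lazily read file-segment buffer, not Spark's `ManagedBuffer` API.

```java
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;

public class LazySegmentSketch {
    /** Hypothetical descriptor: construction records a location and does no I/O. */
    static final class SegmentRef {
        final Path file;
        final long offset;
        final long length;

        SegmentRef(Path file, long offset, long length) {
            this.file = file;
            this.offset = offset;
            this.length = length;
        }

        // Known without touching the file.
        long size() {
            return length;
        }

        // The file is first touched here, so a missing file fails here and
        // not in the constructor. For brevity the stream is not capped at length.
        InputStream open() throws IOException {
            InputStream in = Files.newInputStream(file);
            long remaining = offset;
            while (remaining > 0) {
                long skipped = in.skip(remaining);
                if (skipped <= 0) {
                    break;
                }
                remaining -= skipped;
            }
            return in;
        }
    }

    public static void main(String[] args) {
        // Creating the descriptor succeeds even though the file does not exist.
        SegmentRef ref = new SegmentRef(Paths.get("no-such-merged-file.data"), 0L, 128L);
        System.out.println("created descriptor of size " + ref.size());
        try {
            ref.open().close();
        } catch (IOException e) {
            System.out.println("open failed: " + e.getClass().getSimpleName());
        }
    }
}
```

   Under this assumption, an exception at the point where the buffers are created would come from locating the chunks rather than from reading their contents, which is why the wording of the code comment matters.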






[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r648683013



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -1124,4 +1392,298 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch from a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fall back to fetching the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a merged block.
+   *
+   * @param shuffleId        shuffle id.
+   * @param reduceId         reduce id.
+   * @param blockSize        size of each merged block.
+   * @param numChunks        number of chunks in the merged block.
+   * @param bitmaps          bitmaps for every chunk.
+   * @param address          BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param address   BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+}
+
+/**
+ * Helper class that encapsulates all the push-based functionality to fetch merged block meta
+ * and merged shuffle block chunks.
+ */
+private class PushBasedFetchHelper(

Review comment:
       I have moved `PushBasedFetchHelper` into its own file. 






[GitHub] [spark] otterc commented on pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-859181282


   This is still dependent on the changes in https://github.com/apache/spark/pull/32140, which contains the protocol-side changes.




[GitHub] [spark] otterc commented on a change in pull request #32140: [WIP][SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r645907997



##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} in a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {

Review comment:
       Since this interface only has the `onFailure` method, should I rename it to `FailureCallback`?
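
       To make the naming question concrete, here is a minimal sketch of the hierarchy the Javadoc describes. It assumes that the two sub-interfaces differ only in their success method. The `onSuccess` signatures are assumptions, `ByteBuffer` stands in for the `ManagedBuffer` that the merged block meta response carries, and `CallbackSketch`, `failAll`, and `demo` are hypothetical names used for illustration only, not part of the patch.

```java
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;

// Shared parent: failure handling is the only behaviour common to both request types.
interface BaseResponseCallback {
  void onFailure(Throwable e);
}

// Success shape for a plain RPC (assumed signature).
interface RpcResponseCallback extends BaseResponseCallback {
  void onSuccess(ByteBuffer response);
}

// Success shape for a merged block meta request (assumed signature;
// ByteBuffer is a simplification of ManagedBuffer).
interface MergedBlockMetaResponseCallback extends BaseResponseCallback {
  void onSuccess(int numChunks, ByteBuffer meta);
}

// Hypothetical stand-in for the part of a response handler that tracks
// outstanding requests of both kinds in a single collection.
final class CallbackSketch {

  // One failure path serves every outstanding callback, whatever its success type.
  static int failAll(List<BaseResponseCallback> outstanding, Throwable cause) {
    int notified = 0;
    for (BaseResponseCallback cb : outstanding) {
      cb.onFailure(cause);
      notified++;
    }
    return notified;
  }

  // Registers one callback of each kind, fails both, and returns what they logged.
  static List<String> demo() {
    List<String> log = new ArrayList<>();
    List<BaseResponseCallback> outstanding = new ArrayList<>();

    outstanding.add(new RpcResponseCallback() {
      @Override public void onSuccess(ByteBuffer response) {
        log.add("rpc ok");
      }
      @Override public void onFailure(Throwable e) {
        log.add("rpc failed: " + e.getMessage());
      }
    });

    outstanding.add(new MergedBlockMetaResponseCallback() {
      @Override public void onSuccess(int numChunks, ByteBuffer meta) {
        log.add("meta ok: " + numChunks);
      }
      @Override public void onFailure(Throwable e) {
        log.add("meta failed: " + e.getMessage());
      }
    });

    failAll(outstanding, new RuntimeException("connection closed"));
    return log;
  }
}
```

       In this sketch the parent type is only ever used through `onFailure`, which is the argument for a name like `FailureCallback`; keeping `BaseResponseCallback` would leave room to add other shared methods later.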






[GitHub] [spark] otterc commented on pull request #32140: [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data

Posted by GitBox <gi...@apache.org>.
otterc commented on pull request #32140:
URL: https://github.com/apache/spark/pull/32140#issuecomment-870994802


   Thanks @mridulm and @Ngone51 for the thorough reviews!

