You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by zs...@apache.org on 2016/12/09 23:44:25 UTC

spark git commit: [SPARK-4105] retry the fetch or stage if shuffle block is corrupt

Repository: spark
Updated Branches:
  refs/heads/master d60ab5fd9 -> cf33a8628


[SPARK-4105] retry the fetch or stage if shuffle block is corrupt

## What changes were proposed in this pull request?

There is a long-standing issue: sometimes shuffle blocks are corrupt and can't be decompressed. We recently hit this in three different workloads; sometimes it reproduces on every try, sometimes it doesn't. I also found that when the corruption happened, the beginning and end of the blocks were correct, and the corruption was in the middle. In one case, the block id string was corrupted by a single character. It seems very likely that the corruption is introduced by faulty machines/hardware, and the 16-bit checksum in TCP is not strong enough to detect all such corruption.

Unfortunately, Spark does not have checksums for shuffle blocks or broadcasts, so the job will fail if any corruption happens in a shuffle block read from disk, or in broadcast blocks transferred over the network. This PR tries to detect the corruption after fetching shuffle blocks by decompressing them, because most compression formats already include a checksum. It will retry the block, or fail with a FetchFailure so that the previous stage can be retried on different (still random) machines.

Checksums for broadcast will be added in another PR.

## How was this patch tested?

Added unit tests

Author: Davies Liu <da...@databricks.com>

Closes #15923 from davies/detect_corrupt.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/cf33a862
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/cf33a862
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/cf33a862

Branch: refs/heads/master
Commit: cf33a86285629abe72c1acf235b8bfa6057220a8
Parents: d60ab5f
Author: Davies Liu <da...@databricks.com>
Authored: Fri Dec 9 15:44:22 2016 -0800
Committer: Shixiong Zhu <sh...@databricks.com>
Committed: Fri Dec 9 15:44:22 2016 -0800

----------------------------------------------------------------------
 .../spark/shuffle/BlockStoreShuffleReader.scala |  13 +-
 .../storage/ShuffleBlockFetcherIterator.scala   | 133 +++++++++-----
 .../spark/util/io/ChunkedByteBuffer.scala       |   2 +-
 .../ShuffleBlockFetcherIteratorSuite.scala      | 172 ++++++++++++++++++-
 4 files changed, 263 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/cf33a862/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index b9d8349..8b2e26c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -42,24 +42,21 @@ private[spark] class BlockStoreShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+    val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
       blockManager,
       mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition),
+      serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024,
-      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue))
-
-    // Wrap the streams for compression and encryption based on configuration
-    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      serializerManager.wrapStream(blockId, inputStream)
-    }
+      SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
+      SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
 
     val serializerInstance = dep.serializer.newInstance()
 
     // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { case (blockId, wrappedStream) =>
       // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
       // NextIterator. The NextIterator makes sure that close() is called on the
       // underlying InputStream when all records have been read.

http://git-wip-us.apache.org/repos/asf/spark/blob/cf33a862/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 269c12d..b720aae 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,19 +17,21 @@
 
 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{InputStream, IOException}
+import java.nio.ByteBuffer
 import java.util.concurrent.LinkedBlockingQueue
 import javax.annotation.concurrent.GuardedBy
 
+import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.control.NonFatal
 
 import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.internal.Logging
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util.Utils
+import org.apache.spark.util.io.ChunkedByteBufferOutputStream
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -47,8 +49,10 @@ import org.apache.spark.util.Utils
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
+ * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
+ * @param detectCorrupt whether to detect any corruption in fetched blocks.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
@@ -56,8 +60,10 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
-    maxReqsInFlight: Int)
+    maxReqsInFlight: Int,
+    detectCorrupt: Boolean)
   extends Iterator[(BlockId, InputStream)] with Logging {
 
   import ShuffleBlockFetcherIterator._
@@ -94,7 +100,7 @@ final class ShuffleBlockFetcherIterator(
    * Current [[FetchResult]] being processed. We track this so we can release the current buffer
    * in case of a runtime exception when processing the current buffer.
    */
-  @volatile private[this] var currentResult: FetchResult = null
+  @volatile private[this] var currentResult: SuccessFetchResult = null
 
   /**
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -108,6 +114,12 @@ final class ShuffleBlockFetcherIterator(
   /** Current number of requests in flight */
   private[this] var reqsInFlight = 0
 
+  /**
+   * The blocks that can't be decompressed successfully, it is used to guarantee that we retry
+   * at most once for those corrupted blocks.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
   private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()
 
   /**
@@ -123,9 +135,8 @@ final class ShuffleBlockFetcherIterator(
   // The currentResult is set to null to prevent releasing the buffer again on cleanup()
   private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
-    currentResult match {
-      case SuccessFetchResult(_, _, _, buf, _) => buf.release()
-      case _ =>
+    if (currentResult != null) {
+      currentResult.buf.release()
     }
     currentResult = null
   }
@@ -305,40 +316,84 @@ final class ShuffleBlockFetcherIterator(
    */
   override def next(): (BlockId, InputStream) = {
     numBlocksProcessed += 1
-    val startFetchWait = System.currentTimeMillis()
-    currentResult = results.take()
-    val result = currentResult
-    val stopFetchWait = System.currentTimeMillis()
-    shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
-
-    result match {
-      case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) =>
-        if (address != blockManager.blockManagerId) {
-          shuffleMetrics.incRemoteBytesRead(buf.size)
-          shuffleMetrics.incRemoteBlocksFetched(1)
-        }
-        bytesInFlight -= size
-        if (isNetworkReqDone) {
-          reqsInFlight -= 1
-          logDebug("Number of requests in flight " + reqsInFlight)
-        }
-      case _ =>
-    }
-    // Send fetch requests up to maxBytesInFlight
-    fetchUpToMaxBytes()
 
-    result match {
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
+    var result: FetchResult = null
+    var input: InputStream = null
+    // Take the next fetched result and try to decompress it to detect data corruption,
+    // then fetch it one more time if it's corrupt, throw FailureFetchResult if the second fetch
+    // is also corrupt, so the previous stage could be retried.
+    // For local shuffle block, throw FailureFetchResult for the first IOException.
+    while (result == null) {
+      val startFetchWait = System.currentTimeMillis()
+      result = results.take()
+      val stopFetchWait = System.currentTimeMillis()
+      shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
 
-      case SuccessFetchResult(blockId, address, _, buf, _) =>
-        try {
-          (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this))
-        } catch {
-          case NonFatal(t) =>
-            throwFetchFailedException(blockId, address, t)
-        }
+      result match {
+        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            shuffleMetrics.incRemoteBytesRead(buf.size)
+            shuffleMetrics.incRemoteBlocksFetched(1)
+          }
+          bytesInFlight -= size
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = try {
+            buf.createInputStream()
+          } catch {
+            // The exception could only be throwed by local shuffle block
+            case e: IOException =>
+              assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+              logError("Failed to create input stream from local block", e)
+              buf.release()
+              throwFetchFailedException(blockId, address, e)
+          }
+
+          input = streamWrapper(blockId, in)
+          // Only copy the stream if it's wrapped by compression or encryption, also the size of
+          // block is small (the decompressed block is smaller than maxBytesInFlight)
+          if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) {
+            val originalInput = input
+            val out = new ChunkedByteBufferOutputStream(64 * 1024, ByteBuffer.allocate)
+            try {
+              // Decompress the whole block at once to detect any corruption, which could increase
+              // the memory usage tne potential increase the chance of OOM.
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              Utils.copyStream(input, out)
+              out.close()
+              input = out.toChunkedByteBuffer.toInputStream(dispose = true)
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                  || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, address, e)
+                } else {
+                  logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(address, Array((blockId, size)))
+                  result = null
+                }
+            } finally {
+              // TODO: release the buf here to free memory earlier
+              originalInput.close()
+              in.close()
+            }
+          }
+
+        case FailureFetchResult(blockId, address, e) =>
+          throwFetchFailedException(blockId, address, e)
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
     }
+
+    currentResult = result.asInstanceOf[SuccessFetchResult]
+    (currentResult.blockId, new BufferReleasingInputStream(input, this))
   }
 
   private def fetchUpToMaxBytes(): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/cf33a862/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index da08661..7572cac 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -151,7 +151,7 @@ private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
  * @param dispose if true, `ChunkedByteBuffer.dispose()` will be called at the end of the stream
  *                in order to close any memory-mapped files which back the buffer.
  */
-private class ChunkedByteBufferInputStream(
+private[spark] class ChunkedByteBufferInputStream(
     var chunkedByteBuffer: ChunkedByteBuffer,
     dispose: Boolean)
   extends InputStream {

http://git-wip-us.apache.org/repos/asf/spark/blob/cf33a862/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index e3ec996..e56e440 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.storage
 
-import java.io.InputStream
+import java.io.{File, InputStream, IOException}
 import java.util.concurrent.Semaphore
 
 import scala.concurrent.ExecutionContext.Implicits.global
@@ -31,8 +31,9 @@ import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.{SparkFunSuite, TaskContext}
 import org.apache.spark.network._
-import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.FetchFailedException
 
 
@@ -63,7 +64,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   // Create a mock managed buffer for testing
   def createMockManagedBuffer(): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
-    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
+    val in = mock(classOf[InputStream])
+    when(in.read(any())).thenReturn(1)
+    when(in.read(any(), any(), any())).thenReturn(1)
+    when(mockManagedBuffer.createInputStream()).thenReturn(in)
     mockManagedBuffer
   }
 
@@ -99,8 +103,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)
 
     // 3 local blocks fetched in initialization
     verify(blockManager, times(3)).getBlockData(any())
@@ -172,8 +178,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
     iterator.next()._2.close() // close() first block's input stream
@@ -201,9 +209,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
     )
 
     // Semaphore to coordinate event sequence in two different threads.
@@ -235,8 +243,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       transfer,
       blockManager,
       blocksByAddress,
+      (_, in) => in,
       48 * 1024 * 1024,
-      Int.MaxValue)
+      Int.MaxValue,
+      true)
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
@@ -247,4 +257,148 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     intercept[FetchFailedException] { iterator.next() }
     intercept[FetchFailedException] { iterator.next() }
   }
+
+  test("retry corrupt blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+    )
+
+    // Semaphore to coordinate event sequence in two different threads.
+    val sem = new Semaphore(0)
+
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+    val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
+
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Return the first block, and then fail.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      true)
+
+    // Continue only after the mock calls onBlockFetchFailure
+    sem.acquire()
+
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
+
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Return the first block, and then fail.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    // The next block is corrupt local block (the second one is corrupt and retried)
+    intercept[FetchFailedException] { iterator.next() }
+
+    sem.acquire()
+    intercept[FetchFailedException] { iterator.next() }
+  }
+
+  test("retry corrupt blocks (disabled)") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val blocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()
+    )
+
+    // Semaphore to coordinate event sequence in two different threads.
+    val sem = new Semaphore(0)
+
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    val corruptBuffer = mock(classOf[ManagedBuffer])
+    when(corruptBuffer.createInputStream()).thenReturn(corruptStream)
+
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+      override def answer(invocation: InvocationOnMock): Unit = {
+        val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          // Return the first block, and then fail.
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 1, 0).toString, corruptBuffer)
+          listener.onBlockFetchSuccess(
+            ShuffleBlockId(0, 2, 0).toString, corruptBuffer)
+          sem.release()
+        }
+      }
+    })
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+
+    val taskContext = TaskContext.empty()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => new LimitedInputStream(in, 100),
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      false)
+
+    // Continue only after the mock calls onBlockFetchFailure
+    sem.acquire()
+
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 0))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 1, 0))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 2, 0))
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org