You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by ka...@apache.org on 2015/06/25 07:10:36 UTC

spark git commit: [SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader

Repository: spark
Updated Branches:
  refs/heads/master 82f80c1c7 -> 7bac2fe77


[SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader

This commit updates the shuffle read path to give ShuffleReader implementations more control over the deserialization process.

The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method returned a record iterator; now, it returns an iterator of (BlockId, InputStream). Deserialization of records is now handled in the ShuffleReader.read() method.

This change creates a cleaner separation of concerns and gives implementations of ShuffleReader more flexibility in how records are retrieved.

Author: Matt Massie <ma...@cs.berkeley.edu>
Author: Kay Ousterhout <ka...@gmail.com>

Closes #6423 from massie/shuffle-api-cleanup and squashes the following commits:

8b0632c [Matt Massie] Minor Scala style fixes
d0a1b39 [Matt Massie] Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup
290f1eb [Kay Ousterhout] Added test for HashShuffleReader.read()
5186da0 [Kay Ousterhout] Revert "Add test to ensure HashShuffleReader is freeing resources"
f98a1b9 [Matt Massie] Add test to ensure HashShuffleReader is freeing resources
a011bfa [Matt Massie] Use PrivateMethodTester on check that delegate stream is closed
4ea1712 [Matt Massie] Small code cleanup for readability
7429a98 [Matt Massie] Update tests to check that BufferReleasingStream is closing delegate InputStream
f458489 [Matt Massie] Remove unnecessary map() on return Iterator
4abb855 [Matt Massie] Consolidate metric code. Make it clear why InterrubtibleIterator is needed.
5c30405 [Matt Massie] Return visibility of BlockStoreShuffleFetcher to private[hash]
7eedd1d [Matt Massie] Small Scala import cleanup
28f8085 [Matt Massie] Small import nit
f93841e [Matt Massie] Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher.
7e8e0fe [Matt Massie] Minor Scala style fixes
01e8721 [Matt Massie] Explicitly cast iterator in branches for type clarity
7c8f73e [Matt Massie] Close Block InputStream immediately after all records are read
208b7a5 [Matt Massie] Small code style changes
b70c945 [Matt Massie] Make BlockStoreShuffleFetcher visible to shuffle package
19135f2 [Matt Massie] [SPARK-7884] Allow Spark shuffle APIs to be more customizable


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7bac2fe7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7bac2fe7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7bac2fe7

Branch: refs/heads/master
Commit: 7bac2fe7717c0102b4875dbd95ae0bbf964536e3
Parents: 82f80c1
Author: Matt Massie <ma...@cs.berkeley.edu>
Authored: Wed Jun 24 22:09:31 2015 -0700
Committer: Kay Ousterhout <ka...@gmail.com>
Committed: Wed Jun 24 22:10:06 2015 -0700

----------------------------------------------------------------------
 .../shuffle/hash/BlockStoreShuffleFetcher.scala |  59 +++-----
 .../spark/shuffle/hash/HashShuffleReader.scala  |  52 ++++++-
 .../storage/ShuffleBlockFetcherIterator.scala   |  90 +++++++----
 .../shuffle/hash/HashShuffleReaderSuite.scala   | 150 +++++++++++++++++++
 .../ShuffleBlockFetcherIteratorSuite.scala      |  59 +++++---
 5 files changed, 314 insertions(+), 96 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7bac2fe7/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 597d46a..9d8e7e9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -17,29 +17,29 @@
 
 package org.apache.spark.shuffle.hash
 
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
-import scala.util.{Failure, Success, Try}
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.util.{Failure, Success}
 
 import org.apache.spark._
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
 
 private[hash] object BlockStoreShuffleFetcher extends Logging {
-  def fetch[T](
+  def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
+    : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
 
@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
         (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }
 
-    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.shuffleClient,
+      blockManager,
+      blocksByAddress,
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+
+    // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
+    blockFetcherItr.map { blockPair =>
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
-        case Success(block) => {
-          block.asInstanceOf[Iterator[T]]
+        case Success(inputStream) => {
+          (blockId, inputStream)
         }
         case Failure(e) => {
           blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
         }
       }
     }
-
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
-      context,
-      SparkEnv.get.blockManager.shuffleClient,
-      blockManager,
-      blocksByAddress,
-      serializer,
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-    val itr = blockFetcherItr.flatMap(unpackBlock)
-
-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      context.taskMetrics.updateShuffleReadMetrics()
-    })
-
-    new InterruptibleIterator[T](context, completionIter) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): T = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/7bac2fe7/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 41bafab..d5c9880 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,16 +17,20 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.{InterruptibleIterator, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
+        readMetrics.incRecordsRead(1)
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
-        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
       } else {
-        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure it's compatible w/ this aggregator, which will convert the value
+        // type to the combined type C
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
       }
     } else {
       require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
-
-      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
-      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
     }
 
     // Sort the output if there is a sort ordering defined.

http://git-wip-us.apache.org/repos/asf/spark/blob/7bac2fe7/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index d0faab6..e49e396 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,23 @@
 
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
 import scala.util.{Failure, Try}
 
 import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockId, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
  * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  */
 private[spark]
@@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator(
 
   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]
 
@@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator(
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator(
 
   initialize()
 
-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
       case _ =>
     }
+    currentResult = null
+  }
 
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
@@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator(
 
   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  /**
+   * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
+   * underlying each InputStream will be freed by the cleanup() method registered with the
+   * TaskCompletionListener. However, callers should close() these InputStreams
+   * as soon as they are no longer needed, in order to release memory as early as possible.
+   */
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }
 
-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) =>
         Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
         // There is a chance that createInputStream can fail (e.g. fetching a local file that does
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new BufferReleasingInputStream(inputStream, this)
         }
     }
 
@@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator(
   }
 }
 
+/**
+ * Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+ */
+private class BufferReleasingInputStream(
+    private val delegate: InputStream,
+    private val iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private[this] var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}
 
 private[storage]
 object ShuffleBlockFetcherIterator {

http://git-wip-us.apache.org/repos/asf/spark/blob/7bac2fe7/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
new file mode 100644
index 0000000..28ca686
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+
+/**
+ * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
+ *
+ * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
+ * is final (final classes cannot be spied on).
+ */
+class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
+  var callsToRetain = 0
+  var callsToRelease = 0
+
+  override def size(): Long = underlyingBuffer.size()
+  override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
+  override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
+  override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
+
+  override def retain(): ManagedBuffer = {
+    callsToRetain += 1
+    underlyingBuffer.retain()
+  }
+  override def release(): ManagedBuffer = {
+    callsToRelease += 1
+    underlyingBuffer.release()
+  }
+}
+
+class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
+
+  /**
+   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
+   * ManagedBuffers that contain the data are eventually released.
+   */
+  test("read() releases resources on completion") {
+    val testConf = new SparkConf(false)
+    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
+    // shuffle code calls SparkEnv.get()).
+    sc = new SparkContext("local", "test", testConf)
+
+    val reduceId = 15
+    val shuffleId = 22
+    val numMaps = 6
+    val keyValuePairsPerMap = 10
+    val serializer = new JavaSerializer(testConf)
+
+    // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
+    // can ensure retain() and release() are properly called.
+    val blockManager = mock(classOf[BlockManager])
+
+    // Create a return function to use for the mocked wrapForCompression method that just returns
+    // the original input stream.
+    val dummyCompressionFunction = new Answer[InputStream] {
+      override def answer(invocation: InvocationOnMock): InputStream =
+        invocation.getArguments()(1).asInstanceOf[InputStream]
+    }
+
+    // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
+    // from each mapper (all mappers return the same shuffle data).
+    val byteOutputStream = new ByteArrayOutputStream()
+    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+    (0 until keyValuePairsPerMap).foreach { i =>
+      serializationStream.writeKey(i)
+      serializationStream.writeValue(2*i)
+    }
+
+    // Setup the mocked BlockManager to return RecordingManagedBuffers.
+    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
+    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
+    val buffers = (0 until numMaps).map { mapId =>
+      // Create a ManagedBuffer with the shuffle data.
+      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
+      val managedBuffer = new RecordingManagedBuffer(nioBuffer)
+
+      // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
+      // fetch shuffle data.
+      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
+        .thenAnswer(dummyCompressionFunction)
+
+      managedBuffer
+    }
+
+    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
+    // shuffle data to read.
+    val mapOutputTracker = mock(classOf[MapOutputTracker])
+    // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
+    // for the code to read data over the network.
+    val statuses: Array[(BlockManagerId, Long)] =
+      Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong))
+    when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)
+
+    // Create a mocked shuffle handle to pass into HashShuffleReader.
+    val shuffleHandle = {
+      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+      when(dependency.serializer).thenReturn(Some(serializer))
+      when(dependency.aggregator).thenReturn(None)
+      when(dependency.keyOrdering).thenReturn(None)
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+
+    val shuffleReader = new HashShuffleReader(
+      shuffleHandle,
+      reduceId,
+      reduceId + 1,
+      new TaskContextImpl(0, 0, 0, 0, null),
+      blockManager,
+      mapOutputTracker)
+
+    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
+
+    // Calling .length above will have exhausted the iterator; make sure that exhausting the
+    // iterator caused retain and release to be called on each buffer.
+    buffers.foreach { buffer =>
+      assert(buffer.callsToRetain === 1)
+      assert(buffer.callsToRelease === 1)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7bac2fe7/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 2a7fe67..9ced414 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,23 +17,25 @@
 
 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.Semaphore
 
-import scala.concurrent.future
 import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.future
 
 import org.mockito.Matchers.{any, eq => meq}
 import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl}
+import org.apache.spark.{SparkFunSuite, TaskContextImpl}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
-import org.apache.spark.serializer.TestSerializer
 
-class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
+
+class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
   // Some of the tests are quite tricky because we are testing the cleanup behavior
   // in the presence of faults.
 
@@ -57,7 +59,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     transfer
   }
 
-  private val conf = new SparkConf
+  // Create a mock managed buffer for testing
+  def createMockManagedBuffer(): ManagedBuffer = {
+    val mockManagedBuffer = mock(classOf[ManagedBuffer])
+    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
+    mockManagedBuffer
+  }
 
   test("successful 3 local reads + 2 remote reads") {
     val blockManager = mock(classOf[BlockManager])
@@ -66,9 +73,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
 
     // Make sure blockManager.getBlockData would return the blocks
     val localBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
       doReturn(buf).when(blockManager).getBlockData(meq(blockId))
     }
@@ -76,9 +83,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val remoteBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())
 
     val transfer = createMockTransfer(remoteBlocks)
 
@@ -92,7 +98,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
     // 3 local blocks fetched in initialization
@@ -100,15 +105,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
 
     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
-      val (blockId, subIterator) = iterator.next()
-      assert(subIterator.isSuccess,
+      val (blockId, inputStream) = iterator.next()
+      assert(inputStream.isSuccess,
         s"iterator should have 5 elements defined but actually has $i elements")
 
-      // Make sure we release the buffer once the iterator is exhausted.
+      // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+      // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
+      val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream]
       verify(mockBuf, times(0)).release()
-      subIterator.get.foreach(_ => Unit)  // exhaust the iterator
+      val delegateAccess = PrivateMethod[InputStream]('delegate)
+
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
+      wrappedInputStream.close()
+      verify(mockBuf, times(1)).release()
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+      wrappedInputStream.close() // close should be idempotent
       verify(mockBuf, times(1)).release()
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
     }
 
     // 3 local blocks, and 2 remote blocks
@@ -125,10 +139,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
 
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
@@ -159,11 +172,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
-    // Exhaust the first block, and then it should be released.
-    iterator.next()._2.get.foreach(_ => Unit)
+    verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
+    iterator.next()._2.get.close() // close() first block's input stream
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
 
     // Get the 2nd block but do not exhaust the iterator
@@ -222,7 +234,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
       transfer,
       blockManager,
       blocksByAddress,
-      new TestSerializer,
       48 * 1024 * 1024)
 
     // Continue only after the mock calls onBlockFetchFailure


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org