Posted to commits@spark.apache.org by rx...@apache.org on 2015/09/22 20:50:26 UTC

spark git commit: [SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader

Repository: spark
Updated Branches:
  refs/heads/master 22d40159e -> 1ca5e2e0b


[SPARK-10704] Rename HashShuffleReader to BlockStoreShuffleReader

The current shuffle code has an interface named ShuffleReader with only one implementation, HashShuffleReader. That name is confusing, since the same read-path code is used for both sort- and hash-based shuffle. This patch renames HashShuffleReader to BlockStoreShuffleReader to reflect that the reader fetches blocks from the block store regardless of which shuffle writer produced them.
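
In sketch form, the renamed pieces relate roughly as follows (the trait body is paraphrased from the read() override visible in the diff below; the getReader shape is condensed from the HashShuffleManager and SortShuffleManager hunks, so treat it as a summary rather than an exact excerpt):

    // The single-method reader interface; its sole implementation is the class being renamed.
    private[spark] trait ShuffleReader[K, C] {
      /** Read the combined key-values for this reduce task */
      def read(): Iterator[Product2[K, C]]
    }

    // Both shuffle managers now hand back the same reader from getReader:
    override def getReader[K, C](
        handle: ShuffleHandle,
        startPartition: Int,
        endPartition: Int,
        context: TaskContext): ShuffleReader[K, C] = {
      new BlockStoreShuffleReader(
        handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
    }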

Author: Josh Rosen <jo...@databricks.com>

Closes #8825 from JoshRosen/shuffle-reader-cleanup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1ca5e2e0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1ca5e2e0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1ca5e2e0

Branch: refs/heads/master
Commit: 1ca5e2e0b8d8d406c02a74c76ae9d7fc5637c8d3
Parents: 22d4015
Author: Josh Rosen <jo...@databricks.com>
Authored: Tue Sep 22 11:50:22 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Tue Sep 22 11:50:22 2015 -0700

----------------------------------------------------------------------
 .../spark/shuffle/BlockStoreShuffleReader.scala | 111 +++++++++++++
 .../spark/shuffle/hash/HashShuffleManager.scala |   2 +-
 .../spark/shuffle/hash/HashShuffleReader.scala  | 112 --------------
 .../spark/shuffle/sort/SortShuffleManager.scala |   3 +-
 .../shuffle/BlockStoreShuffleReaderSuite.scala  | 153 ++++++++++++++++++
 .../shuffle/hash/HashShuffleReaderSuite.scala   | 154 -------------------
 6 files changed, 266 insertions(+), 269 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1ca5e2e0/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
new file mode 100644
index 0000000..6dc9a16
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.collection.ExternalSorter
+
+private[spark] class BlockStoreShuffleReader[K, C](
+    handle: BaseShuffleHandle[K, _, C],
+    startPartition: Int,
+    endPartition: Int,
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
+  extends ShuffleReader[K, C] with Logging {
+
+  require(endPartition == startPartition + 1,
+    "BlockStoreShuffleReader currently only supports fetching one partition")
+
+  private val dep = handle.dependency
+
+  /** Read the combined key-values for this reduce task */
+  override def read(): Iterator[Product2[K, C]] = {
+    val blockFetcherItr = new ShuffleBlockFetcherIterator(
+      context,
+      blockManager.shuffleClient,
+      blockManager,
+      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
+      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
+      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
+    val ser = Serializer.getSerializer(dep.serializer)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
+      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
+      // NextIterator. The NextIterator makes sure that close() is called on the
+      // underlying InputStream when all records have been read.
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update the context task metrics for each record read.
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
+        readMetrics.incRecordsRead(1)
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())
+
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
+
+    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        // We are reading values that are already combined
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
+      } else {
+        // We don't know the value type, but also don't care -- the dependency *should*
+        // have made sure it's compatible with this aggregator, which will convert the value
+        // type to the combined type C
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
+      }
+    } else {
+      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
+    }
+
+    // Sort the output if there is a sort ordering defined.
+    dep.keyOrdering match {
+      case Some(keyOrd: Ordering[K]) =>
+        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
+        // the ExternalSorter won't spill to disk.
+        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
+        sorter.insertAll(aggregatedIter)
+        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
+        context.internalMetricsToAccumulators(
+          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
+        sorter.iterator
+      case None =>
+        aggregatedIter
+    }
+  }
+}
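
For orientation (not part of this diff), the reader above is obtained through ShuffleManager.getReader and drained by the reduce task. The sketch below is an assumption modeled on the Spark 1.5 ShuffledRDD.compute path; readReducePartition is a hypothetical helper name used only for illustration:

    import org.apache.spark.{Partition, ShuffleDependency, SparkEnv, TaskContext}

    // Hypothetical helper (illustration only): how a reduce task obtains and drains the reader.
    def readReducePartition[K, V, C](
        dep: ShuffleDependency[K, V, C],
        split: Partition,
        context: TaskContext): Iterator[(K, C)] = {
      // After this patch, getReader returns a BlockStoreShuffleReader for both the
      // hash- and sort-based shuffle managers.
      SparkEnv.get.shuffleManager
        .getReader(dep.shuffleHandle, split.index, split.index + 1, context)
        .read()
        .asInstanceOf[Iterator[(K, C)]]
    }

The startPartition/endPartition pair spans exactly one partition, which matches the require(endPartition == startPartition + 1) check in the reader above.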

http://git-wip-us.apache.org/repos/asf/spark/blob/1ca5e2e0/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
index 0b46634..d2e2fc4 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -51,7 +51,7 @@ private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager
       startPartition: Int,
       endPartition: Int,
       context: TaskContext): ShuffleReader[K, C] = {
-    new HashShuffleReader(
+    new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1ca5e2e0/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
deleted file mode 100644
index 0c8f08f..0000000
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ /dev/null
@@ -1,112 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import org.apache.spark._
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
-import org.apache.spark.util.CompletionIterator
-import org.apache.spark.util.collection.ExternalSorter
-
-private[spark] class HashShuffleReader[K, C](
-    handle: BaseShuffleHandle[K, _, C],
-    startPartition: Int,
-    endPartition: Int,
-    context: TaskContext,
-    blockManager: BlockManager = SparkEnv.get.blockManager,
-    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
-  extends ShuffleReader[K, C] with Logging {
-
-  require(endPartition == startPartition + 1,
-    "Hash shuffle currently only supports fetching one partition")
-
-  private val dep = handle.dependency
-
-  /** Read the combined key-values for this reduce task */
-  override def read(): Iterator[Product2[K, C]] = {
-    val blockFetcherItr = new ShuffleBlockFetcherIterator(
-      context,
-      blockManager.shuffleClient,
-      blockManager,
-      mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition),
-      // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
-      SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-
-    // Wrap the streams for compression based on configuration
-    val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) =>
-      blockManager.wrapForCompression(blockId, inputStream)
-    }
-
-    val ser = Serializer.getSerializer(dep.serializer)
-    val serializerInstance = ser.newInstance()
-
-    // Create a key/value iterator for each stream
-    val recordIter = wrappedStreams.flatMap { wrappedStream =>
-      // Note: the asKeyValueIterator below wraps a key/value iterator inside of a
-      // NextIterator. The NextIterator makes sure that close() is called on the
-      // underlying InputStream when all records have been read.
-      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
-    }
-
-    // Update the context task metrics for each record read.
-    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
-      recordIter.map(record => {
-        readMetrics.incRecordsRead(1)
-        record
-      }),
-      context.taskMetrics().updateShuffleReadMetrics())
-
-    // An interruptible iterator must be used here in order to support task cancellation
-    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)
-
-    val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
-      if (dep.mapSideCombine) {
-        // We are reading values that are already combined
-        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
-        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
-      } else {
-        // We don't know the value type, but also don't care -- the dependency *should*
-        // have made sure its compatible w/ this aggregator, which will convert the value
-        // type to the combined type C
-        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
-        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
-      }
-    } else {
-      require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
-      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
-    }
-
-    // Sort the output if there is a sort ordering defined.
-    dep.keyOrdering match {
-      case Some(keyOrd: Ordering[K]) =>
-        // Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
-        // the ExternalSorter won't spill to disk.
-        val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
-        sorter.insertAll(aggregatedIter)
-        context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
-        context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
-        context.internalMetricsToAccumulators(
-          InternalAccumulator.PEAK_EXECUTION_MEMORY).add(sorter.peakMemoryUsedBytes)
-        sorter.iterator
-      case None =>
-        aggregatedIter
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/1ca5e2e0/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 476cc1f..9df4e55 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -21,7 +21,6 @@ import java.util.concurrent.ConcurrentHashMap
 
 import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency}
 import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.hash.HashShuffleReader
 
 private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
@@ -54,7 +53,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       endPartition: Int,
       context: TaskContext): ShuffleReader[K, C] = {
     // We currently use the same block store shuffle fetcher as the hash-based shuffle.
-    new HashShuffleReader(
+    new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1ca5e2e0/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
new file mode 100644
index 0000000..a5eafb1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+
+/**
+ * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
+ *
+ * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
+ * is final (final classes cannot be spied on).
+ */
+class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
+  var callsToRetain = 0
+  var callsToRelease = 0
+
+  override def size(): Long = underlyingBuffer.size()
+  override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
+  override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
+  override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
+
+  override def retain(): ManagedBuffer = {
+    callsToRetain += 1
+    underlyingBuffer.retain()
+  }
+  override def release(): ManagedBuffer = {
+    callsToRelease += 1
+    underlyingBuffer.release()
+  }
+}
+
+class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
+
+  /**
+   * This test makes sure that, when data is read from a BlockStoreShuffleReader, the underlying
+   * ManagedBuffers that contain the data are eventually released.
+   */
+  test("read() releases resources on completion") {
+    val testConf = new SparkConf(false)
+    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
+    // shuffle code calls SparkEnv.get()).
+    sc = new SparkContext("local", "test", testConf)
+
+    val reduceId = 15
+    val shuffleId = 22
+    val numMaps = 6
+    val keyValuePairsPerMap = 10
+    val serializer = new JavaSerializer(testConf)
+
+    // Make a mock BlockManager that will return RecordingManagedBuffers of data, so that we
+    // can ensure retain() and release() are properly called.
+    val blockManager = mock(classOf[BlockManager])
+
+    // Create a return function to use for the mocked wrapForCompression method that just returns
+    // the original input stream.
+    val dummyCompressionFunction = new Answer[InputStream] {
+      override def answer(invocation: InvocationOnMock): InputStream =
+        invocation.getArguments()(1).asInstanceOf[InputStream]
+    }
+
+    // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
+    // from each mapper (all mappers return the same shuffle data).
+    val byteOutputStream = new ByteArrayOutputStream()
+    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+    (0 until keyValuePairsPerMap).foreach { i =>
+      serializationStream.writeKey(i)
+      serializationStream.writeValue(2*i)
+    }
+
+    // Setup the mocked BlockManager to return RecordingManagedBuffers.
+    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
+    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
+    val buffers = (0 until numMaps).map { mapId =>
+      // Create a ManagedBuffer with the shuffle data.
+      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
+      val managedBuffer = new RecordingManagedBuffer(nioBuffer)
+
+      // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
+      // fetch shuffle data.
+      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
+        .thenAnswer(dummyCompressionFunction)
+
+      managedBuffer
+    }
+
+    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
+    // shuffle data to read.
+    val mapOutputTracker = mock(classOf[MapOutputTracker])
+    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn {
+      // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
+      // for the code to read data over the network.
+      val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
+        val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+        (shuffleBlockId, byteOutputStream.size().toLong)
+      }
+      Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
+    }
+
+    // Create a mocked shuffle handle to pass into BlockStoreShuffleReader.
+    val shuffleHandle = {
+      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+      when(dependency.serializer).thenReturn(Some(serializer))
+      when(dependency.aggregator).thenReturn(None)
+      when(dependency.keyOrdering).thenReturn(None)
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+
+    val shuffleReader = new BlockStoreShuffleReader(
+      shuffleHandle,
+      reduceId,
+      reduceId + 1,
+      TaskContext.empty(),
+      blockManager,
+      mapOutputTracker)
+
+    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
+
+    // Calling .length above will have exhausted the iterator; make sure that exhausting the
+    // iterator caused retain and release to be called on each buffer.
+    buffers.foreach { buffer =>
+      assert(buffer.callsToRetain === 1)
+      assert(buffer.callsToRelease === 1)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1ca5e2e0/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
deleted file mode 100644
index 05b3afe..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala
+++ /dev/null
@@ -1,154 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.hash
-
-import java.io.{ByteArrayOutputStream, InputStream}
-import java.nio.ByteBuffer
-
-import org.mockito.Matchers.{eq => meq, _}
-import org.mockito.Mockito.{mock, when}
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-
-import org.apache.spark._
-import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.BaseShuffleHandle
-import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
-
-/**
- * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
- *
- * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
- * is final (final classes cannot be spied on).
- */
-class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
-  var callsToRetain = 0
-  var callsToRelease = 0
-
-  override def size(): Long = underlyingBuffer.size()
-  override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer()
-  override def createInputStream(): InputStream = underlyingBuffer.createInputStream()
-  override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty()
-
-  override def retain(): ManagedBuffer = {
-    callsToRetain += 1
-    underlyingBuffer.retain()
-  }
-  override def release(): ManagedBuffer = {
-    callsToRelease += 1
-    underlyingBuffer.release()
-  }
-}
-
-class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
-
-  /**
-   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
-   * ManagedBuffers that contain the data are eventually released.
-   */
-  test("read() releases resources on completion") {
-    val testConf = new SparkConf(false)
-    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
-    // shuffle code calls SparkEnv.get()).
-    sc = new SparkContext("local", "test", testConf)
-
-    val reduceId = 15
-    val shuffleId = 22
-    val numMaps = 6
-    val keyValuePairsPerMap = 10
-    val serializer = new JavaSerializer(testConf)
-
-    // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we
-    // can ensure retain() and release() are properly called.
-    val blockManager = mock(classOf[BlockManager])
-
-    // Create a return function to use for the mocked wrapForCompression method that just returns
-    // the original input stream.
-    val dummyCompressionFunction = new Answer[InputStream] {
-      override def answer(invocation: InvocationOnMock): InputStream =
-        invocation.getArguments()(1).asInstanceOf[InputStream]
-    }
-
-    // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
-    // from each mappers (all mappers return the same shuffle data).
-    val byteOutputStream = new ByteArrayOutputStream()
-    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
-    (0 until keyValuePairsPerMap).foreach { i =>
-      serializationStream.writeKey(i)
-      serializationStream.writeValue(2*i)
-    }
-
-    // Setup the mocked BlockManager to return RecordingManagedBuffers.
-    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
-    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
-    val buffers = (0 until numMaps).map { mapId =>
-      // Create a ManagedBuffer with the shuffle data.
-      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
-      val managedBuffer = new RecordingManagedBuffer(nioBuffer)
-
-      // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
-      // fetch shuffle data.
-      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
-      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
-        .thenAnswer(dummyCompressionFunction)
-
-      managedBuffer
-    }
-
-    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
-    // shuffle data to read.
-    val mapOutputTracker = mock(classOf[MapOutputTracker])
-    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn {
-      // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
-      // for the code to read data over the network.
-      val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
-        val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-        (shuffleBlockId, byteOutputStream.size().toLong)
-      }
-      Seq((localBlockManagerId, shuffleBlockIdsAndSizes))
-    }
-
-    // Create a mocked shuffle handle to pass into HashShuffleReader.
-    val shuffleHandle = {
-      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
-      when(dependency.serializer).thenReturn(Some(serializer))
-      when(dependency.aggregator).thenReturn(None)
-      when(dependency.keyOrdering).thenReturn(None)
-      new BaseShuffleHandle(shuffleId, numMaps, dependency)
-    }
-
-    val shuffleReader = new HashShuffleReader(
-      shuffleHandle,
-      reduceId,
-      reduceId + 1,
-      TaskContext.empty(),
-      blockManager,
-      mapOutputTracker)
-
-    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
-
-    // Calling .length above will have exhausted the iterator; make sure that exhausting the
-    // iterator caused retain and release to be called on each buffer.
-    buffers.foreach { buffer =>
-      assert(buffer.callsToRetain === 1)
-      assert(buffer.callsToRelease === 1)
-    }
-  }
-}

