You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jo...@apache.org on 2015/10/22 18:46:39 UTC

[1/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations

Repository: spark
Updated Branches:
  refs/heads/master 94e2064fa -> f6d06adf0


http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 341f56d..b92a302 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -33,7 +33,8 @@ import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
 import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
-import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.serializer.{JavaSerializer, SerializerInstance}
 import org.apache.spark.storage._
 import org.apache.spark.util.Utils
 
@@ -42,25 +43,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockResolver: IndexShuffleBlockResolver = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
 
   private var taskMetrics: TaskMetrics = _
-  private var shuffleWriteMetrics: ShuffleWriteMetrics = _
   private var tempDir: File = _
   private var outputFile: File = _
   private val conf: SparkConf = new SparkConf(loadDefaults = false)
   private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
   private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
-  private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
-  private val serializer: Serializer = new JavaSerializer(conf)
+  private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _
 
   override def beforeEach(): Unit = {
     tempDir = Utils.createTempDir()
     outputFile = File.createTempFile("shuffle", null, tempDir)
-    shuffleWriteMetrics = new ShuffleWriteMetrics
     taskMetrics = new TaskMetrics
-    taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
     MockitoAnnotations.initMocks(this)
+    shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int](
+      shuffleId = 0,
+      numMaps = 2,
+      dependency = dependency
+    )
+    when(dependency.partitioner).thenReturn(new HashPartitioner(7))
+    when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
     when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+    when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
     when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
     when(blockManager.getDiskWriter(
       any[BlockId],
@@ -107,18 +114,20 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
 
   test("write empty iterator") {
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      new SparkConf(loadDefaults = false),
       blockManager,
-      new HashPartitioner(7),
-      shuffleWriteMetrics,
-      serializer
+      blockResolver,
+      shuffleHandle,
+      0, // MapId
+      taskContext,
+      conf
     )
-    writer.insertAll(Iterator.empty)
-    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
-    assert(partitionLengths.sum === 0)
+    writer.write(Iterator.empty)
+    writer.stop( /* success = */ true)
+    assert(writer.getPartitionLengths.sum === 0)
     assert(outputFile.exists())
     assert(outputFile.length() === 0)
     assert(temporaryFilesCreated.isEmpty)
+    val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
     assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
     assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
     assert(taskMetrics.diskBytesSpilled === 0)
@@ -129,17 +138,19 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
     def records: Iterator[(Int, Int)] =
       Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      new SparkConf(loadDefaults = false),
       blockManager,
-      new HashPartitioner(7),
-      shuffleWriteMetrics,
-      serializer
+      blockResolver,
+      shuffleHandle,
+      0, // MapId
+      taskContext,
+      conf
     )
-    writer.insertAll(records)
+    writer.write(records)
+    writer.stop( /* success = */ true)
     assert(temporaryFilesCreated.nonEmpty)
-    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
-    assert(partitionLengths.sum === outputFile.length())
+    assert(writer.getPartitionLengths.sum === outputFile.length())
     assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+    val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get
     assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
     assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
     assert(taskMetrics.diskBytesSpilled === 0)
@@ -148,14 +159,15 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
 
   test("cleanup of intermediate files after errors") {
     val writer = new BypassMergeSortShuffleWriter[Int, Int](
-      new SparkConf(loadDefaults = false),
       blockManager,
-      new HashPartitioner(7),
-      shuffleWriteMetrics,
-      serializer
+      blockResolver,
+      shuffleHandle,
+      0, // MapId
+      taskContext,
+      conf
     )
     intercept[SparkException] {
-      writer.insertAll((0 until 100000).iterator.map(i => {
+      writer.write((0 until 100000).iterator.map(i => {
         if (i == 99990) {
           throw new SparkException("Intentional failure")
         }
@@ -163,7 +175,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       }))
     }
     assert(temporaryFilesCreated.nonEmpty)
-    writer.stop()
+    writer.stop( /* success = */ false)
     assert(temporaryFilesCreated.count(_.exists()) === 0)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
new file mode 100644
index 0000000..8744a07
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.Matchers
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in SortShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
+
+  import SortShuffleManager.canUseSerializedShuffle
+
+  private class RuntimeExceptionAnswer extends Answer[Object] {
+    override def answer(invocation: InvocationOnMock): Object = {
+      throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+    }
+  }
+
+  private def shuffleDep(
+      partitioner: Partitioner,
+      serializer: Option[Serializer],
+      keyOrdering: Option[Ordering[Any]],
+      aggregator: Option[Aggregator[Any, Any, Any]],
+      mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+    val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+    doReturn(0).when(dep).shuffleId
+    doReturn(partitioner).when(dep).partitioner
+    doReturn(serializer).when(dep).serializer
+    doReturn(keyOrdering).when(dep).keyOrdering
+    doReturn(aggregator).when(dep).aggregator
+    doReturn(mapSideCombine).when(dep).mapSideCombine
+    dep
+  }
+
+  test("supported shuffle dependencies for serialized shuffle") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+
+    assert(canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+    when(rangePartitioner.numPartitions).thenReturn(2)
+    assert(canUseSerializedShuffle(shuffleDep(
+      partitioner = rangePartitioner,
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // Shuffles with key orderings are supported as long as no aggregator is specified
+    assert(canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+  }
+
+  test("unsupported shuffle dependencies for serialized shuffle") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+    val java = Some(new JavaSerializer(new SparkConf()))
+
+    // We only support serializers that support object relocation
+    assert(!canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = java,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // The serialized shuffle path does not support shuffles with more than 16 million output
+    // partitions, due to a limitation in its sorter implementation.
+    assert(!canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(
+        SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE + 1),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles that perform aggregation
+    assert(!canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = false
+    )))
+    assert(!canUseSerializedShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = true
+    )))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
deleted file mode 100644
index 34b4984..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort
-
-import org.mockito.Mockito._
-
-import org.apache.spark.{Aggregator, SparkConf, SparkFunSuite}
-
-class SortShuffleWriterSuite extends SparkFunSuite {
-
-  import SortShuffleWriter._
-
-  test("conditions for bypassing merge-sort") {
-    val conf = new SparkConf(loadDefaults = false)
-    val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
-    val ord = implicitly[Ordering[Int]]
-
-    // Numbers of partitions that are above and below the default bypassMergeThreshold
-    val FEW_PARTITIONS = 50
-    val MANY_PARTITIONS = 10000
-
-    // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
-    assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
-    assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
-
-    // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
-    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
-    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
deleted file mode 100644
index 6727934..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
+++ /dev/null
@@ -1,129 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.Matchers
-
-import org.apache.spark._
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
-
-/**
- * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
- * performed in other suites.
- */
-class UnsafeShuffleManagerSuite extends SparkFunSuite with Matchers {
-
-  import UnsafeShuffleManager.canUseUnsafeShuffle
-
-  private class RuntimeExceptionAnswer extends Answer[Object] {
-    override def answer(invocation: InvocationOnMock): Object = {
-      throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
-    }
-  }
-
-  private def shuffleDep(
-      partitioner: Partitioner,
-      serializer: Option[Serializer],
-      keyOrdering: Option[Ordering[Any]],
-      aggregator: Option[Aggregator[Any, Any, Any]],
-      mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
-    val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
-    doReturn(0).when(dep).shuffleId
-    doReturn(partitioner).when(dep).partitioner
-    doReturn(serializer).when(dep).serializer
-    doReturn(keyOrdering).when(dep).keyOrdering
-    doReturn(aggregator).when(dep).aggregator
-    doReturn(mapSideCombine).when(dep).mapSideCombine
-    dep
-  }
-
-  test("supported shuffle dependencies") {
-    val kryo = Some(new KryoSerializer(new SparkConf()))
-
-    assert(canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(2),
-      serializer = kryo,
-      keyOrdering = None,
-      aggregator = None,
-      mapSideCombine = false
-    )))
-
-    val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
-    when(rangePartitioner.numPartitions).thenReturn(2)
-    assert(canUseUnsafeShuffle(shuffleDep(
-      partitioner = rangePartitioner,
-      serializer = kryo,
-      keyOrdering = None,
-      aggregator = None,
-      mapSideCombine = false
-    )))
-
-    // Shuffles with key orderings are supported as long as no aggregator is specified
-    assert(canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(2),
-      serializer = kryo,
-      keyOrdering = Some(mock(classOf[Ordering[Any]])),
-      aggregator = None,
-      mapSideCombine = false
-    )))
-
-  }
-
-  test("unsupported shuffle dependencies") {
-    val kryo = Some(new KryoSerializer(new SparkConf()))
-    val java = Some(new JavaSerializer(new SparkConf()))
-
-    // We only support serializers that support object relocation
-    assert(!canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(2),
-      serializer = java,
-      keyOrdering = None,
-      aggregator = None,
-      mapSideCombine = false
-    )))
-
-    // We do not support shuffles with more than 16 million output partitions
-    assert(!canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
-      serializer = kryo,
-      keyOrdering = None,
-      aggregator = None,
-      mapSideCombine = false
-    )))
-
-    // We do not support shuffles that perform aggregation
-    assert(!canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(2),
-      serializer = kryo,
-      keyOrdering = None,
-      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
-      mapSideCombine = false
-    )))
-    assert(!canUseUnsafeShuffle(shuffleDep(
-      partitioner = new HashPartitioner(2),
-      serializer = kryo,
-      keyOrdering = Some(mock(classOf[Ordering[Any]])),
-      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
-      mapSideCombine = true
-    )))
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
deleted file mode 100644
index 259020a..0000000
--- a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.io.File
-
-import scala.collection.JavaConverters._
-
-import org.apache.commons.io.FileUtils
-import org.apache.commons.io.filefilter.TrueFileFilter
-import org.scalatest.BeforeAndAfterAll
-
-import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
-import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
-import org.apache.spark.util.Utils
-
-class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
-
-  // This test suite should run all tests in ShuffleSuite with unsafe-based shuffle.
-
-  override def beforeAll() {
-    conf.set("spark.shuffle.manager", "tungsten-sort")
-  }
-
-  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
-    val tmpDir = Utils.createTempDir()
-    try {
-      val myConf = conf.clone()
-        .set("spark.local.dir", tmpDir.getAbsolutePath)
-      sc = new SparkContext("local", "test", myConf)
-      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
-      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
-      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
-        .setSerializer(new KryoSerializer(myConf))
-      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
-      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
-      def getAllFiles: Set[File] =
-        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
-      val filesBeforeShuffle = getAllFiles
-      // Force the shuffle to be performed
-      shuffledRdd.count()
-      // Ensure that the shuffle actually created files that will need to be cleaned up
-      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
-      filesCreatedByShuffle.map(_.getName) should be
-        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
-      // Check that the cleanup actually removes the files
-      sc.env.blockManager.master.removeShuffle(0, blocking = true)
-      for (file <- filesCreatedByShuffle) {
-        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
-      }
-    } finally {
-      Utils.deleteRecursively(tmpDir)
-    }
-  }
-
-  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
-    val tmpDir = Utils.createTempDir()
-    try {
-      val myConf = conf.clone()
-        .set("spark.local.dir", tmpDir.getAbsolutePath)
-      sc = new SparkContext("local", "test", myConf)
-      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
-      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
-      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
-        .setSerializer(new JavaSerializer(myConf))
-      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
-      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
-      def getAllFiles: Set[File] =
-        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
-      val filesBeforeShuffle = getAllFiles
-      // Force the shuffle to be performed
-      shuffledRdd.count()
-      // Ensure that the shuffle actually created files that will need to be cleaned up
-      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
-      filesCreatedByShuffle.map(_.getName) should be
-        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index")
-      // Check that the cleanup actually removes the files
-      sc.env.blockManager.master.removeShuffle(0, blocking = true)
-      for (file <- filesCreatedByShuffle) {
-        assert (!file.exists(), s"Shuffle file $file was not cleaned up")
-      }
-    } finally {
-      Utils.deleteRecursively(tmpDir)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
deleted file mode 100644
index 05306f4..0000000
--- a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
+++ /dev/null
@@ -1,144 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.nio.ByteBuffer
-
-import org.scalatest.Matchers._
-
-import org.apache.spark.SparkFunSuite
-
-class ChainedBufferSuite extends SparkFunSuite {
-  test("write and read at start") {
-    // write from start of source array
-    val buffer = new ChainedBuffer(8)
-    buffer.capacity should be (0)
-    verifyWriteAndRead(buffer, 0, 0, 0, 4)
-    buffer.capacity should be (8)
-
-    // write from middle of source array
-    verifyWriteAndRead(buffer, 0, 5, 0, 4)
-    buffer.capacity should be (8)
-
-    // read to middle of target array
-    verifyWriteAndRead(buffer, 0, 0, 5, 4)
-    buffer.capacity should be (8)
-
-    // write up to border
-    verifyWriteAndRead(buffer, 0, 0, 0, 8)
-    buffer.capacity should be (8)
-
-    // expand into second buffer
-    verifyWriteAndRead(buffer, 0, 0, 0, 12)
-    buffer.capacity should be (16)
-
-    // expand into multiple buffers
-    verifyWriteAndRead(buffer, 0, 0, 0, 28)
-    buffer.capacity should be (32)
-  }
-
-  test("write and read at middle") {
-    val buffer = new ChainedBuffer(8)
-
-    // fill to a middle point
-    verifyWriteAndRead(buffer, 0, 0, 0, 3)
-
-    // write from start of source array
-    verifyWriteAndRead(buffer, 3, 0, 0, 4)
-    buffer.capacity should be (8)
-
-    // write from middle of source array
-    verifyWriteAndRead(buffer, 3, 5, 0, 4)
-    buffer.capacity should be (8)
-
-    // read to middle of target array
-    verifyWriteAndRead(buffer, 3, 0, 5, 4)
-    buffer.capacity should be (8)
-
-    // write up to border
-    verifyWriteAndRead(buffer, 3, 0, 0, 5)
-    buffer.capacity should be (8)
-
-    // expand into second buffer
-    verifyWriteAndRead(buffer, 3, 0, 0, 12)
-    buffer.capacity should be (16)
-
-    // expand into multiple buffers
-    verifyWriteAndRead(buffer, 3, 0, 0, 28)
-    buffer.capacity should be (32)
-  }
-
-  test("write and read at later buffer") {
-    val buffer = new ChainedBuffer(8)
-
-    // fill to a middle point
-    verifyWriteAndRead(buffer, 0, 0, 0, 11)
-
-    // write from start of source array
-    verifyWriteAndRead(buffer, 11, 0, 0, 4)
-    buffer.capacity should be (16)
-
-    // write from middle of source array
-    verifyWriteAndRead(buffer, 11, 5, 0, 4)
-    buffer.capacity should be (16)
-
-    // read to middle of target array
-    verifyWriteAndRead(buffer, 11, 0, 5, 4)
-    buffer.capacity should be (16)
-
-    // write up to border
-    verifyWriteAndRead(buffer, 11, 0, 0, 5)
-    buffer.capacity should be (16)
-
-    // expand into second buffer
-    verifyWriteAndRead(buffer, 11, 0, 0, 12)
-    buffer.capacity should be (24)
-
-    // expand into multiple buffers
-    verifyWriteAndRead(buffer, 11, 0, 0, 28)
-    buffer.capacity should be (40)
-  }
-
-
-  // Used to make sure we're writing different bytes each time
-  var rangeStart = 0
-
-  /**
-   * @param buffer The buffer to write to and read from.
-   * @param offsetInBuffer The offset to write to in the buffer.
-   * @param offsetInSource The offset in the array that the bytes are written from.
-   * @param offsetInTarget The offset in the array to read the bytes into.
-   * @param length The number of bytes to read and write
-   */
-  def verifyWriteAndRead(
-      buffer: ChainedBuffer,
-      offsetInBuffer: Int,
-      offsetInSource: Int,
-      offsetInTarget: Int,
-      length: Int): Unit = {
-    val source = new Array[Byte](offsetInSource + length)
-    (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
-    buffer.write(offsetInBuffer, source, offsetInSource, length)
-    val target = new Array[Byte](offsetInTarget + length)
-    buffer.read(offsetInBuffer, target, offsetInTarget, length)
-    ByteBuffer.wrap(source, offsetInSource, length) should be
-      (ByteBuffer.wrap(target, offsetInTarget, length))
-
-    rangeStart += 100
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
deleted file mode 100644
index 3b67f62..0000000
--- a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
-
-import com.google.common.io.ByteStreams
-
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.Mockito.RETURNS_SMART_NULLS
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-import org.scalatest.Matchers._
-
-import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.serializer.KryoSerializer
-import org.apache.spark.storage.DiskBlockObjectWriter
-
-class PartitionedSerializedPairBufferSuite extends SparkFunSuite {
-  test("OrderedInputStream single record") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct = SomeStruct("something", 5)
-    buffer.insert(4, 10, struct)
-
-    val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
-
-    val baos = new ByteArrayOutputStream()
-    val stream = serializerInstance.serializeStream(baos)
-    stream.writeObject(10)
-    stream.writeObject(struct)
-    stream.close()
-
-    baos.toByteArray should be (bytes)
-  }
-
-  test("insert single record") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct = SomeStruct("something", 5)
-    buffer.insert(4, 10, struct)
-    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
-    elements.size should be (1)
-    elements.head should be (((4, 10), struct))
-  }
-
-  test("insert multiple records") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct1 = SomeStruct("something1", 8)
-    buffer.insert(6, 1, struct1)
-    val struct2 = SomeStruct("something2", 9)
-    buffer.insert(4, 2, struct2)
-    val struct3 = SomeStruct("something3", 10)
-    buffer.insert(5, 3, struct3)
-
-    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
-    elements.size should be (3)
-    elements(0) should be (((4, 2), struct2))
-    elements(1) should be (((5, 3), struct3))
-    elements(2) should be (((6, 1), struct1))
-  }
-
-  test("write single record") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct = SomeStruct("something", 5)
-    buffer.insert(4, 10, struct)
-    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
-    val (writer, baos) = createMockWriter()
-    assert(it.hasNext)
-    it.nextPartition should be (4)
-    it.writeNext(writer)
-    assert(!it.hasNext)
-
-    val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
-    stream.readObject[AnyRef]() should be (10)
-    stream.readObject[AnyRef]() should be (struct)
-  }
-
-  test("write multiple records") {
-    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
-    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
-    val struct1 = SomeStruct("something1", 8)
-    buffer.insert(6, 1, struct1)
-    val struct2 = SomeStruct("something2", 9)
-    buffer.insert(4, 2, struct2)
-    val struct3 = SomeStruct("something3", 10)
-    buffer.insert(5, 3, struct3)
-
-    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
-    val (writer, baos) = createMockWriter()
-    assert(it.hasNext)
-    it.nextPartition should be (4)
-    it.writeNext(writer)
-    assert(it.hasNext)
-    it.nextPartition should be (5)
-    it.writeNext(writer)
-    assert(it.hasNext)
-    it.nextPartition should be (6)
-    it.writeNext(writer)
-    assert(!it.hasNext)
-
-    val stream = serializerInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
-    val iter = stream.asIterator
-    iter.next() should be (2)
-    iter.next() should be (struct2)
-    iter.next() should be (3)
-    iter.next() should be (struct3)
-    iter.next() should be (1)
-    iter.next() should be (struct1)
-    assert(!iter.hasNext)
-  }
-
-  def createMockWriter(): (DiskBlockObjectWriter, ByteArrayOutputStream) = {
-    val writer = mock(classOf[DiskBlockObjectWriter], RETURNS_SMART_NULLS)
-    val baos = new ByteArrayOutputStream()
-    when(writer.write(any(), any(), any())).thenAnswer(new Answer[Unit] {
-      override def answer(invocationOnMock: InvocationOnMock): Unit = {
-        val args = invocationOnMock.getArguments
-        val bytes = args(0).asInstanceOf[Array[Byte]]
-        val offset = args(1).asInstanceOf[Int]
-        val length = args(2).asInstanceOf[Int]
-        baos.write(bytes, offset, length)
-      }
-    })
-    (writer, baos)
-  }
-}
-
-case class SomeStruct(str: String, num: Int)

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/docs/configuration.md
----------------------------------------------------------------------
diff --git a/docs/configuration.md b/docs/configuration.md
index 46d92ce..be9c36b 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -437,12 +437,9 @@ Apart from these, the following properties are also available, and may be useful
   <td><code>spark.shuffle.manager</code></td>
   <td>sort</td>
   <td>
-    Implementation to use for shuffling data. There are three implementations available:
-    <code>sort</code>, <code>hash</code> and the new (1.5+) <code>tungsten-sort</code>.
+    Implementation to use for shuffling data. There are two implementations available:
+    <code>sort</code> and <code>hash</code>.
     Sort-based shuffle is more memory-efficient and is the default option starting in 1.2.
-    Tungsten-sort is similar to the sort based shuffle, with a direct binary cache-friendly
-    implementation with a fall back to regular sort based shuffle if its requirements are not
-    met.
   </td>
 </tr>
 <tr>

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0872d3f..b5e661d 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -37,6 +37,7 @@ object MimaExcludes {
       Seq(
         MimaBuild.excludeSparkPackage("deploy"),
         MimaBuild.excludeSparkPackage("network"),
+        MimaBuild.excludeSparkPackage("unsafe"),
         // These are needed if checking against the sbt build, since they are part of
         // the maven-generated artifacts in 1.3.
         excludePackage("org.spark-project.jetty"),
@@ -44,7 +45,11 @@ object MimaExcludes {
         // SQL execution is considered private.
         excludePackage("org.apache.spark.sql.execution"),
         // SQL columnar is considered private.
-        excludePackage("org.apache.spark.sql.columnar")
+        excludePackage("org.apache.spark.sql.columnar"),
+        // The shuffle package is considered private.
+        excludePackage("org.apache.spark.shuffle"),
+        // The collections utilities are considered private.
+        excludePackage("org.apache.spark.util.collection")
       ) ++
       MimaBuild.excludeSparkClass("streaming.flume.FlumeTestUtils") ++
       MimaBuild.excludeSparkClass("streaming.flume.PollingFlumeTestUtils") ++
@@ -750,4 +755,4 @@ object MimaExcludes {
       MimaBuild.excludeSparkClass("mllib.regression.LinearRegressionWithSGD")
     case _ => Seq()
   }
-}
\ No newline at end of file
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 1d3379a..7f60c8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -24,7 +24,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.hash.HashShuffleManager
 import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
 import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors.attachTree
@@ -87,10 +86,8 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
     // fewer partitions (like RangePartitioner, for example).
     val conf = child.sqlContext.sparkContext.conf
     val shuffleManager = SparkEnv.get.shuffleManager
-    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
-      shuffleManager.isInstanceOf[UnsafeShuffleManager]
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
     if (sortBasedShuffleOn) {
       val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
       if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
@@ -99,22 +96,18 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         // doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
         false
-      } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
-        // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting
-        // them. This optimization is guarded by a feature-flag and is only applied in cases where
-        // shuffle dependency does not specify an aggregator or ordering and the record serializer
-        // has certain properties. If this optimization is enabled, we can safely avoid the copy.
+      } else if (serializer.supportsRelocationOfSerializedObjects) {
+        // SPARK-4550 and SPARK-7081 extended sort-based shuffle to serialize individual records
+        // prior to sorting them. This optimization is only applied in cases where shuffle
+        // dependency does not specify an aggregator or ordering and the record serializer has
+        // certain properties. If this optimization is enabled, we can safely avoid the copy.
         //
         // Exchange never configures its ShuffledRDDs with aggregators or key orderings, so we only
         // need to check whether the optimization is enabled and supported by our serializer.
-        //
-        // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
         false
       } else {
-        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
-        // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
-        // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
-        // both cases, we must copy.
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory, so we must
+        // copy.
         true
       }
     } else if (shuffleManager.isInstanceOf[HashShuffleManager]) {

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 75d1fce..1680d7e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -101,7 +101,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
     val oldEnv = SparkEnv.get // save the old SparkEnv, as it will be overwritten
     Utils.tryWithSafeFinally {
       val conf = new SparkConf()
-        .set("spark.shuffle.spill.initialMemoryThreshold", "1024")
+        .set("spark.shuffle.spill.initialMemoryThreshold", "1")
         .set("spark.shuffle.sort.bypassMergeThreshold", "0")
         .set("spark.testing.memory", "80000")
 
@@ -109,7 +109,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
       outputFile = File.createTempFile("test-unsafe-row-serializer-spill", "")
       // prepare data
       val converter = unsafeRowConverter(Array(IntegerType))
-      val data = (1 to 1000).iterator.map { i =>
+      val data = (1 to 10000).iterator.map { i =>
         (i, converter(Row(i)))
       }
       val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow](
@@ -141,9 +141,8 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext {
     }
   }
 
-  test("SPARK-10403: unsafe row serializer with UnsafeShuffleManager") {
-    val conf = new SparkConf()
-      .set("spark.shuffle.manager", "tungsten-sort")
+  test("SPARK-10403: unsafe row serializer with SortShuffleManager") {
+    val conf = new SparkConf().set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local", "test", conf)
     val row = Row("Hello", 123)
     val unsafeRow = toUnsafeRow(row, Array(StringType, IntegerType))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[3/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations

Posted by jo...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
deleted file mode 100644
index e73ba39..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
+++ /dev/null
@@ -1,479 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import javax.annotation.Nullable;
-import java.io.File;
-import java.io.IOException;
-import java.util.LinkedList;
-
-import scala.Tuple2;
-
-import com.google.common.annotations.VisibleForTesting;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.serializer.DummySerializerInstance;
-import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.DiskBlockObjectWriter;
-import org.apache.spark.storage.TempShuffleBlockId;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.array.ByteArrayMethods;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.Utils;
-
-/**
- * An external sorter that is specialized for sort-based shuffle.
- * <p>
- * Incoming records are appended to data pages. When all records have been inserted (or when the
- * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
- * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
- * written to a single output file (or multiple files, if we've spilled). The format of the output
- * files is the same as the format of the final output file written by
- * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
- * written as a single serialized, compressed stream that can be read with a new decompression and
- * deserialization stream.
- * <p>
- * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
- * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
- * specialized merge procedure that avoids extra serialization/deserialization.
- */
-final class UnsafeShuffleExternalSorter {
-
-  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
-
-  @VisibleForTesting
-  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-
-  private final int initialSize;
-  private final int numPartitions;
-  private final int pageSizeBytes;
-  @VisibleForTesting
-  final int maxRecordSizeBytes;
-  private final TaskMemoryManager taskMemoryManager;
-  private final ShuffleMemoryManager shuffleMemoryManager;
-  private final BlockManager blockManager;
-  private final TaskContext taskContext;
-  private final ShuffleWriteMetrics writeMetrics;
-
-  /** The buffer size to use when writing spills using DiskBlockObjectWriter */
-  private final int fileBufferSizeBytes;
-
-  /**
-   * Memory pages that hold the records being sorted. The pages in this list are freed when
-   * spilling, although in principle we could recycle these pages across spills (on the other hand,
-   * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
-   * itself).
-   */
-  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
-
-  private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
-
-  /** Peak memory used by this sorter so far, in bytes. **/
-  private long peakMemoryUsedBytes;
-
-  // These variables are reset after spilling:
-  @Nullable private UnsafeShuffleInMemorySorter inMemSorter;
-  @Nullable private MemoryBlock currentPage = null;
-  private long currentPagePosition = -1;
-  private long freeSpaceInCurrentPage = 0;
-
-  public UnsafeShuffleExternalSorter(
-      TaskMemoryManager memoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
-      BlockManager blockManager,
-      TaskContext taskContext,
-      int initialSize,
-      int numPartitions,
-      SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) throws IOException {
-    this.taskMemoryManager = memoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
-    this.blockManager = blockManager;
-    this.taskContext = taskContext;
-    this.initialSize = initialSize;
-    this.peakMemoryUsedBytes = initialSize;
-    this.numPartitions = numPartitions;
-    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
-    this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
-    this.pageSizeBytes = (int) Math.min(
-      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
-    this.maxRecordSizeBytes = pageSizeBytes - 4;
-    this.writeMetrics = writeMetrics;
-    initializeForWriting();
-
-    // preserve first page to ensure that we have at least one page to work with. Otherwise,
-    // other operators in the same task may starve this sorter (SPARK-9709).
-    acquireNewPageIfNecessary(pageSizeBytes);
-  }
-
-  /**
-   * Allocates new sort data structures. Called when creating the sorter and after each spill.
-   */
-  private void initializeForWriting() throws IOException {
-    // TODO: move this sizing calculation logic into a static method of sorter:
-    final long memoryRequested = initialSize * 8L;
-    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
-    if (memoryAcquired != memoryRequested) {
-      shuffleMemoryManager.release(memoryAcquired);
-      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
-    }
-
-    this.inMemSorter = new UnsafeShuffleInMemorySorter(initialSize);
-  }
-
-  /**
-   * Sorts the in-memory records and writes the sorted records to an on-disk file.
-   * This method does not free the sort data structures.
-   *
-   * @param isLastFile if true, this indicates that we're writing the final output file and that the
-   *                   bytes written should be counted towards shuffle spill metrics rather than
-   *                   shuffle write metrics.
-   */
-  private void writeSortedFile(boolean isLastFile) throws IOException {
-
-    final ShuffleWriteMetrics writeMetricsToUse;
-
-    if (isLastFile) {
-      // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
-      writeMetricsToUse = writeMetrics;
-    } else {
-      // We're spilling, so bytes written should be counted towards spill rather than write.
-      // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
-      // them towards shuffle bytes written.
-      writeMetricsToUse = new ShuffleWriteMetrics();
-    }
-
-    // This call performs the actual sort.
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
-      inMemSorter.getSortedIterator();
-
-    // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
-    // after SPARK-5581 is fixed.
-    DiskBlockObjectWriter writer;
-
-    // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
-    // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
-    // data through a byte array. This array does not need to be large enough to hold a single
-    // record;
-    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
-
-    // Because this output will be read during shuffle, its compression codec must be controlled by
-    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
-    // createTempShuffleBlock here; see SPARK-3426 for more details.
-    final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
-      blockManager.diskBlockManager().createTempShuffleBlock();
-    final File file = spilledFileInfo._2();
-    final TempShuffleBlockId blockId = spilledFileInfo._1();
-    final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
-
-    // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
-    // Our write path doesn't actually use this serializer (since we end up calling the `write()`
-    // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
-    // around this, we pass a dummy no-op serializer.
-    final SerializerInstance ser = DummySerializerInstance.INSTANCE;
-
-    writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
-
-    int currentPartition = -1;
-    while (sortedRecords.hasNext()) {
-      sortedRecords.loadNext();
-      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
-      assert (partition >= currentPartition);
-      if (partition != currentPartition) {
-        // Switch to the new partition
-        if (currentPartition != -1) {
-          writer.commitAndClose();
-          spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
-        }
-        currentPartition = partition;
-        writer =
-          blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
-      }
-
-      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
-      final Object recordPage = taskMemoryManager.getPage(recordPointer);
-      final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
-      int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
-      long recordReadPosition = recordOffsetInPage + 4; // skip over record length
-      while (dataRemaining > 0) {
-        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
-        Platform.copyMemory(
-          recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
-        writer.write(writeBuffer, 0, toTransfer);
-        recordReadPosition += toTransfer;
-        dataRemaining -= toTransfer;
-      }
-      writer.recordWritten();
-    }
-
-    if (writer != null) {
-      writer.commitAndClose();
-      // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
-      // then the file might be empty. Note that it might be better to avoid calling
-      // writeSortedFile() in that case.
-      if (currentPartition != -1) {
-        spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
-        spills.add(spillInfo);
-      }
-    }
-
-    if (!isLastFile) {  // i.e. this is a spill file
-      // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
-      // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
-      // relies on its `recordWritten()` method being called in order to trigger periodic updates to
-      // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
-      // counter at a higher-level, then the in-progress metrics for records written and bytes
-      // written would get out of sync.
-      //
-      // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
-      // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
-      // metrics to the true write metrics here. The reason for performing this copying is so that
-      // we can avoid reporting spilled bytes as shuffle write bytes.
-      //
-      // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
-      // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
-      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
-      writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
-      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
-    }
-  }
-
-  /**
-   * Sort and spill the current records in response to memory pressure.
-   */
-  @VisibleForTesting
-  void spill() throws IOException {
-    logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
-      Thread.currentThread().getId(),
-      Utils.bytesToString(getMemoryUsage()),
-      spills.size(),
-      spills.size() > 1 ? " times" : " time");
-
-    writeSortedFile(false);
-    final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
-    inMemSorter = null;
-    shuffleMemoryManager.release(inMemSorterMemoryUsage);
-    final long spillSize = freeMemory();
-    taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
-
-    initializeForWriting();
-  }
-
-  private long getMemoryUsage() {
-    long totalPageSize = 0;
-    for (MemoryBlock page : allocatedPages) {
-      totalPageSize += page.size();
-    }
-    return ((inMemSorter == null) ? 0 : inMemSorter.getMemoryUsage()) + totalPageSize;
-  }
-
-  private void updatePeakMemoryUsed() {
-    long mem = getMemoryUsage();
-    if (mem > peakMemoryUsedBytes) {
-      peakMemoryUsedBytes = mem;
-    }
-  }
-
-  /**
-   * Return the peak memory used so far, in bytes.
-   */
-  long getPeakMemoryUsedBytes() {
-    updatePeakMemoryUsed();
-    return peakMemoryUsedBytes;
-  }
-
-  private long freeMemory() {
-    updatePeakMemoryUsed();
-    long memoryFreed = 0;
-    for (MemoryBlock block : allocatedPages) {
-      taskMemoryManager.freePage(block);
-      shuffleMemoryManager.release(block.size());
-      memoryFreed += block.size();
-    }
-    allocatedPages.clear();
-    currentPage = null;
-    currentPagePosition = -1;
-    freeSpaceInCurrentPage = 0;
-    return memoryFreed;
-  }
-
-  /**
-   * Force all memory and spill files to be deleted; called by shuffle error-handling code.
-   */
-  public void cleanupResources() {
-    freeMemory();
-    for (SpillInfo spill : spills) {
-      if (spill.file.exists() && !spill.file.delete()) {
-        logger.error("Unable to delete spill file {}", spill.file.getPath());
-      }
-    }
-    if (inMemSorter != null) {
-      shuffleMemoryManager.release(inMemSorter.getMemoryUsage());
-      inMemSorter = null;
-    }
-  }
-
-  /**
-   * Checks whether there is enough space to insert an additional record in to the sort pointer
-   * array and grows the array if additional space is required. If the required space cannot be
-   * obtained, then the in-memory data will be spilled to disk.
-   */
-  private void growPointerArrayIfNecessary() throws IOException {
-    assert(inMemSorter != null);
-    if (!inMemSorter.hasSpaceForAnotherRecord()) {
-      logger.debug("Attempting to expand sort pointer array");
-      final long oldPointerArrayMemoryUsage = inMemSorter.getMemoryUsage();
-      final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
-      final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
-      if (memoryAcquired < memoryToGrowPointerArray) {
-        shuffleMemoryManager.release(memoryAcquired);
-        spill();
-      } else {
-        inMemSorter.expandPointerArray();
-        shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
-      }
-    }
-  }
-  
-  /**
-   * Allocates more memory in order to insert an additional record. This will request additional
-   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
-   * obtained.
-   *
-   * @param requiredSpace the required space in the data page, in bytes, including space for storing
-   *                      the record size. This must be less than or equal to the page size (records
-   *                      that exceed the page size are handled via a different code path which uses
-   *                      special overflow pages).
-   */
-  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
-    growPointerArrayIfNecessary();
-    if (requiredSpace > freeSpaceInCurrentPage) {
-      logger.trace("Required space {} is less than free space in current page ({})", requiredSpace,
-        freeSpaceInCurrentPage);
-      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
-      // without using the free space at the end of the current page. We should also do this for
-      // BytesToBytesMap.
-      if (requiredSpace > pageSizeBytes) {
-        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
-          pageSizeBytes + ")");
-      } else {
-        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-        if (memoryAcquired < pageSizeBytes) {
-          shuffleMemoryManager.release(memoryAcquired);
-          spill();
-          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
-          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
-            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
-            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
-          }
-        }
-        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
-        currentPagePosition = currentPage.getBaseOffset();
-        freeSpaceInCurrentPage = pageSizeBytes;
-        allocatedPages.add(currentPage);
-      }
-    }
-  }
-
-  /**
-   * Write a record to the shuffle sorter.
-   */
-  public void insertRecord(
-      Object recordBaseObject,
-      long recordBaseOffset,
-      int lengthInBytes,
-      int partitionId) throws IOException {
-
-    growPointerArrayIfNecessary();
-    // Need 4 bytes to store the record length.
-    final int totalSpaceRequired = lengthInBytes + 4;
-
-    // --- Figure out where to insert the new record ----------------------------------------------
-
-    final MemoryBlock dataPage;
-    long dataPagePosition;
-    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
-    if (useOverflowPage) {
-      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
-      // The record is larger than the page size, so allocate a special overflow page just to hold
-      // that record.
-      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-      if (memoryGranted != overflowPageSize) {
-        shuffleMemoryManager.release(memoryGranted);
-        spill();
-        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
-        if (memoryGrantedAfterSpill != overflowPageSize) {
-          shuffleMemoryManager.release(memoryGrantedAfterSpill);
-          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
-        }
-      }
-      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
-      allocatedPages.add(overflowPage);
-      dataPage = overflowPage;
-      dataPagePosition = overflowPage.getBaseOffset();
-    } else {
-      // The record is small enough to fit in a regular data page, but the current page might not
-      // have enough space to hold it (or no pages have been allocated yet).
-      acquireNewPageIfNecessary(totalSpaceRequired);
-      dataPage = currentPage;
-      dataPagePosition = currentPagePosition;
-      // Update bookkeeping information
-      freeSpaceInCurrentPage -= totalSpaceRequired;
-      currentPagePosition += totalSpaceRequired;
-    }
-    final Object dataPageBaseObject = dataPage.getBaseObject();
-
-    final long recordAddress =
-      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
-    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
-    dataPagePosition += 4;
-    Platform.copyMemory(
-      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
-    assert(inMemSorter != null);
-    inMemSorter.insertRecord(recordAddress, partitionId);
-  }
-
-  /**
-   * Close the sorter, causing any buffered data to be sorted and written out to disk.
-   *
-   * @return metadata for the spill files written by this sorter. If no records were ever inserted
-   *         into this sorter, then this will return an empty array.
-   * @throws IOException
-   */
-  public SpillInfo[] closeAndGetSpills() throws IOException {
-    try {
-      if (inMemSorter != null) {
-        // Do not count the final file towards the spill count.
-        writeSortedFile(true);
-        freeMemory();
-      }
-      return spills.toArray(new SpillInfo[spills.size()]);
-    } catch (IOException e) {
-      cleanupResources();
-      throw e;
-    }
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
deleted file mode 100644
index 5bab501..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.util.Comparator;
-
-import org.apache.spark.util.collection.Sorter;
-
-final class UnsafeShuffleInMemorySorter {
-
-  private final Sorter<PackedRecordPointer, long[]> sorter;
-  private static final class SortComparator implements Comparator<PackedRecordPointer> {
-    @Override
-    public int compare(PackedRecordPointer left, PackedRecordPointer right) {
-      return left.getPartitionId() - right.getPartitionId();
-    }
-  }
-  private static final SortComparator SORT_COMPARATOR = new SortComparator();
-
-  /**
-   * An array of record pointers and partition ids that have been encoded by
-   * {@link PackedRecordPointer}. The sort operates on this array instead of directly manipulating
-   * records.
-   */
-  private long[] pointerArray;
-
-  /**
-   * The position in the pointer array where new records can be inserted.
-   */
-  private int pointerArrayInsertPosition = 0;
-
-  public UnsafeShuffleInMemorySorter(int initialSize) {
-    assert (initialSize > 0);
-    this.pointerArray = new long[initialSize];
-    this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
-  }
-
-  public void expandPointerArray() {
-    final long[] oldArray = pointerArray;
-    // Guard against overflow:
-    final int newLength = oldArray.length * 2 > 0 ? (oldArray.length * 2) : Integer.MAX_VALUE;
-    pointerArray = new long[newLength];
-    System.arraycopy(oldArray, 0, pointerArray, 0, oldArray.length);
-  }
-
-  public boolean hasSpaceForAnotherRecord() {
-    return pointerArrayInsertPosition + 1 < pointerArray.length;
-  }
-
-  public long getMemoryUsage() {
-    return pointerArray.length * 8L;
-  }
-
-  /**
-   * Inserts a record to be sorted.
-   *
-   * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
-   *                      certain pointer compression techniques used by the sorter, the sort can
-   *                      only operate on pointers that point to locations in the first
-   *                      {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
-   * @param partitionId the partition id, which must be less than or equal to
-   *                    {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
-   */
-  public void insertRecord(long recordPointer, int partitionId) {
-    if (!hasSpaceForAnotherRecord()) {
-      if (pointerArray.length == Integer.MAX_VALUE) {
-        throw new IllegalStateException("Sort pointer array has reached maximum size");
-      } else {
-        expandPointerArray();
-      }
-    }
-    pointerArray[pointerArrayInsertPosition] =
-        PackedRecordPointer.packPointer(recordPointer, partitionId);
-    pointerArrayInsertPosition++;
-  }
-
-  /**
-   * An iterator-like class that's used instead of Java's Iterator in order to facilitate inlining.
-   */
-  public static final class UnsafeShuffleSorterIterator {
-
-    private final long[] pointerArray;
-    private final int numRecords;
-    final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
-    private int position = 0;
-
-    public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
-      this.numRecords = numRecords;
-      this.pointerArray = pointerArray;
-    }
-
-    public boolean hasNext() {
-      return position < numRecords;
-    }
-
-    public void loadNext() {
-      packedRecordPointer.set(pointerArray[position]);
-      position++;
-    }
-  }
-
-  /**
-   * Return an iterator over record pointers in sorted order.
-   */
-  public UnsafeShuffleSorterIterator getSortedIterator() {
-    sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
-    return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
deleted file mode 100644
index a66d74e..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import org.apache.spark.util.collection.SortDataFormat;
-
-final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
-
-  public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
-
-  private UnsafeShuffleSortDataFormat() { }
-
-  @Override
-  public PackedRecordPointer getKey(long[] data, int pos) {
-    // Since we re-use keys, this method shouldn't be called.
-    throw new UnsupportedOperationException();
-  }
-
-  @Override
-  public PackedRecordPointer newKey() {
-    return new PackedRecordPointer();
-  }
-
-  @Override
-  public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
-    reuse.set(data[pos]);
-    return reuse;
-  }
-
-  @Override
-  public void swap(long[] data, int pos0, int pos1) {
-    final long temp = data[pos0];
-    data[pos0] = data[pos1];
-    data[pos1] = temp;
-  }
-
-  @Override
-  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
-    dst[dstPos] = src[srcPos];
-  }
-
-  @Override
-  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
-    System.arraycopy(src, srcPos, dst, dstPos, length);
-  }
-
-  @Override
-  public long[] allocate(int length) {
-    return new long[length];
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
deleted file mode 100644
index fdb309e..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
+++ /dev/null
@@ -1,489 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import javax.annotation.Nullable;
-import java.io.*;
-import java.nio.channels.FileChannel;
-import java.util.Iterator;
-
-import scala.Option;
-import scala.Product2;
-import scala.collection.JavaConverters;
-import scala.collection.immutable.Map;
-import scala.reflect.ClassTag;
-import scala.reflect.ClassTag$;
-
-import com.google.common.annotations.VisibleForTesting;
-import com.google.common.io.ByteStreams;
-import com.google.common.io.Closeables;
-import com.google.common.io.Files;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import org.apache.spark.*;
-import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.io.CompressionCodec;
-import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.scheduler.MapStatus$;
-import org.apache.spark.serializer.SerializationStream;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.shuffle.ShuffleWriter;
-import org.apache.spark.storage.BlockManager;
-import org.apache.spark.storage.TimeTrackingOutputStream;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-
-@Private
-public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
-
-  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
-
-  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
-
-  @VisibleForTesting
-  static final int INITIAL_SORT_BUFFER_SIZE = 4096;
-
-  private final BlockManager blockManager;
-  private final IndexShuffleBlockResolver shuffleBlockResolver;
-  private final TaskMemoryManager memoryManager;
-  private final ShuffleMemoryManager shuffleMemoryManager;
-  private final SerializerInstance serializer;
-  private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
-  private final int shuffleId;
-  private final int mapId;
-  private final TaskContext taskContext;
-  private final SparkConf sparkConf;
-  private final boolean transferToEnabled;
-
-  @Nullable private MapStatus mapStatus;
-  @Nullable private UnsafeShuffleExternalSorter sorter;
-  private long peakMemoryUsedBytes = 0;
-
-  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
-  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
-    public MyByteArrayOutputStream(int size) { super(size); }
-    public byte[] getBuf() { return buf; }
-  }
-
-  private MyByteArrayOutputStream serBuffer;
-  private SerializationStream serOutputStream;
-
-  /**
-   * Are we in the process of stopping? Because map tasks can call stop() with success = true
-   * and then call stop() with success = false if they get an exception, we want to make sure
-   * we don't try deleting files, etc twice.
-   */
-  private boolean stopping = false;
-
-  public UnsafeShuffleWriter(
-      BlockManager blockManager,
-      IndexShuffleBlockResolver shuffleBlockResolver,
-      TaskMemoryManager memoryManager,
-      ShuffleMemoryManager shuffleMemoryManager,
-      UnsafeShuffleHandle<K, V> handle,
-      int mapId,
-      TaskContext taskContext,
-      SparkConf sparkConf) throws IOException {
-    final int numPartitions = handle.dependency().partitioner().numPartitions();
-    if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {
-      throw new IllegalArgumentException(
-        "UnsafeShuffleWriter can only be used for shuffles with at most " +
-          UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
-    }
-    this.blockManager = blockManager;
-    this.shuffleBlockResolver = shuffleBlockResolver;
-    this.memoryManager = memoryManager;
-    this.shuffleMemoryManager = shuffleMemoryManager;
-    this.mapId = mapId;
-    final ShuffleDependency<K, V, V> dep = handle.dependency();
-    this.shuffleId = dep.shuffleId();
-    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
-    this.partitioner = dep.partitioner();
-    this.writeMetrics = new ShuffleWriteMetrics();
-    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
-    this.taskContext = taskContext;
-    this.sparkConf = sparkConf;
-    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
-    open();
-  }
-
-  @VisibleForTesting
-  public int maxRecordSizeBytes() {
-    assert(sorter != null);
-    return sorter.maxRecordSizeBytes;
-  }
-
-  private void updatePeakMemoryUsed() {
-    // sorter can be null if this writer is closed
-    if (sorter != null) {
-      long mem = sorter.getPeakMemoryUsedBytes();
-      if (mem > peakMemoryUsedBytes) {
-        peakMemoryUsedBytes = mem;
-      }
-    }
-  }
-
-  /**
-   * Return the peak memory used so far, in bytes.
-   */
-  public long getPeakMemoryUsedBytes() {
-    updatePeakMemoryUsed();
-    return peakMemoryUsedBytes;
-  }
-
-  /**
-   * This convenience method should only be called in test code.
-   */
-  @VisibleForTesting
-  public void write(Iterator<Product2<K, V>> records) throws IOException {
-    write(JavaConverters.asScalaIteratorConverter(records).asScala());
-  }
-
-  @Override
-  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
-    // Keep track of success so we know if we encountered an exception
-    // We do this rather than a standard try/catch/re-throw to handle
-    // generic throwables.
-    boolean success = false;
-    try {
-      while (records.hasNext()) {
-        insertRecordIntoSorter(records.next());
-      }
-      closeAndWriteOutput();
-      success = true;
-    } finally {
-      if (sorter != null) {
-        try {
-          sorter.cleanupResources();
-        } catch (Exception e) {
-          // Only throw this error if we won't be masking another
-          // error.
-          if (success) {
-            throw e;
-          } else {
-            logger.error("In addition to a failure during writing, we failed during " +
-                         "cleanup.", e);
-          }
-        }
-      }
-    }
-  }
-
-  private void open() throws IOException {
-    assert (sorter == null);
-    sorter = new UnsafeShuffleExternalSorter(
-      memoryManager,
-      shuffleMemoryManager,
-      blockManager,
-      taskContext,
-      INITIAL_SORT_BUFFER_SIZE,
-      partitioner.numPartitions(),
-      sparkConf,
-      writeMetrics);
-    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
-    serOutputStream = serializer.serializeStream(serBuffer);
-  }
-
-  @VisibleForTesting
-  void closeAndWriteOutput() throws IOException {
-    assert(sorter != null);
-    updatePeakMemoryUsed();
-    serBuffer = null;
-    serOutputStream = null;
-    final SpillInfo[] spills = sorter.closeAndGetSpills();
-    sorter = null;
-    final long[] partitionLengths;
-    try {
-      partitionLengths = mergeSpills(spills);
-    } finally {
-      for (SpillInfo spill : spills) {
-        if (spill.file.exists() && ! spill.file.delete()) {
-          logger.error("Error while deleting spill file {}", spill.file.getPath());
-        }
-      }
-    }
-    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
-    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
-  }
-
-  @VisibleForTesting
-  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
-    assert(sorter != null);
-    final K key = record._1();
-    final int partitionId = partitioner.getPartition(key);
-    serBuffer.reset();
-    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
-    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
-    serOutputStream.flush();
-
-    final int serializedRecordSize = serBuffer.size();
-    assert (serializedRecordSize > 0);
-
-    sorter.insertRecord(
-      serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
-  }
-
-  @VisibleForTesting
-  void forceSorterToSpill() throws IOException {
-    assert (sorter != null);
-    sorter.spill();
-  }
-
-  /**
-   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
-   * number of spills and the IO compression codec.
-   *
-   * @return the partition lengths in the merged file.
-   */
-  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
-    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
-    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
-    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
-    final boolean fastMergeEnabled =
-      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
-    final boolean fastMergeIsSupported =
-      !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
-    try {
-      if (spills.length == 0) {
-        new FileOutputStream(outputFile).close(); // Create an empty file
-        return new long[partitioner.numPartitions()];
-      } else if (spills.length == 1) {
-        // Here, we don't need to perform any metrics updates because the bytes written to this
-        // output file would have already been counted as shuffle bytes written.
-        Files.move(spills[0].file, outputFile);
-        return spills[0].partitionLengths;
-      } else {
-        final long[] partitionLengths;
-        // There are multiple spills to merge, so none of these spill files' lengths were counted
-        // towards our shuffle write count or shuffle write time. If we use the slow merge path,
-        // then the final output file's size won't necessarily be equal to the sum of the spill
-        // files' sizes. To guard against this case, we look at the output file's actual size when
-        // computing shuffle bytes written.
-        //
-        // We allow the individual merge methods to report their own IO times since different merge
-        // strategies use different IO techniques.  We count IO during merge towards the shuffle
-        // shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
-        // branch in ExternalSorter.
-        if (fastMergeEnabled && fastMergeIsSupported) {
-          // Compression is disabled or we are using an IO compression codec that supports
-          // decompression of concatenated compressed streams, so we can perform a fast spill merge
-          // that doesn't need to interpret the spilled bytes.
-          if (transferToEnabled) {
-            logger.debug("Using transferTo-based fast merge");
-            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
-          } else {
-            logger.debug("Using fileStream-based fast merge");
-            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
-          }
-        } else {
-          logger.debug("Using slow merge");
-          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
-        }
-        // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
-        // in-memory records, we write out the in-memory records to a file but do not count that
-        // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
-        // to be counted as shuffle write, but this will lead to double-counting of the final
-        // SpillInfo's bytes.
-        writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
-        writeMetrics.incShuffleBytesWritten(outputFile.length());
-        return partitionLengths;
-      }
-    } catch (IOException e) {
-      if (outputFile.exists() && !outputFile.delete()) {
-        logger.error("Unable to delete output file {}", outputFile.getPath());
-      }
-      throw e;
-    }
-  }
-
-  /**
-   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
-   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
-   * cases where the IO compression codec does not support concatenation of compressed data, or in
-   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
-   * kernel bugs.
-   *
-   * @param spills the spills to merge.
-   * @param outputFile the file to write the merged data to.
-   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
-   * @return the partition lengths in the merged file.
-   */
-  private long[] mergeSpillsWithFileStream(
-      SpillInfo[] spills,
-      File outputFile,
-      @Nullable CompressionCodec compressionCodec) throws IOException {
-    assert (spills.length >= 2);
-    final int numPartitions = partitioner.numPartitions();
-    final long[] partitionLengths = new long[numPartitions];
-    final InputStream[] spillInputStreams = new FileInputStream[spills.length];
-    OutputStream mergedFileOutputStream = null;
-
-    boolean threwException = true;
-    try {
-      for (int i = 0; i < spills.length; i++) {
-        spillInputStreams[i] = new FileInputStream(spills[i].file);
-      }
-      for (int partition = 0; partition < numPartitions; partition++) {
-        final long initialFileLength = outputFile.length();
-        mergedFileOutputStream =
-          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
-        if (compressionCodec != null) {
-          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
-        }
-
-        for (int i = 0; i < spills.length; i++) {
-          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-          if (partitionLengthInSpill > 0) {
-            InputStream partitionInputStream =
-              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
-            if (compressionCodec != null) {
-              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
-            }
-            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
-          }
-        }
-        mergedFileOutputStream.flush();
-        mergedFileOutputStream.close();
-        partitionLengths[partition] = (outputFile.length() - initialFileLength);
-      }
-      threwException = false;
-    } finally {
-      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
-      // throw exceptions during cleanup if threwException == false.
-      for (InputStream stream : spillInputStreams) {
-        Closeables.close(stream, threwException);
-      }
-      Closeables.close(mergedFileOutputStream, threwException);
-    }
-    return partitionLengths;
-  }
-
-  /**
-   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
-   * This is only safe when the IO compression codec and serializer support concatenation of
-   * serialized streams.
-   *
-   * @return the partition lengths in the merged file.
-   */
-  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
-    assert (spills.length >= 2);
-    final int numPartitions = partitioner.numPartitions();
-    final long[] partitionLengths = new long[numPartitions];
-    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
-    final long[] spillInputChannelPositions = new long[spills.length];
-    FileChannel mergedFileOutputChannel = null;
-
-    boolean threwException = true;
-    try {
-      for (int i = 0; i < spills.length; i++) {
-        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
-      }
-      // This file needs to opened in append mode in order to work around a Linux kernel bug that
-      // affects transferTo; see SPARK-3948 for more details.
-      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
-
-      long bytesWrittenToMergedFile = 0;
-      for (int partition = 0; partition < numPartitions; partition++) {
-        for (int i = 0; i < spills.length; i++) {
-          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
-          long bytesToTransfer = partitionLengthInSpill;
-          final FileChannel spillInputChannel = spillInputChannels[i];
-          final long writeStartTime = System.nanoTime();
-          while (bytesToTransfer > 0) {
-            final long actualBytesTransferred = spillInputChannel.transferTo(
-              spillInputChannelPositions[i],
-              bytesToTransfer,
-              mergedFileOutputChannel);
-            spillInputChannelPositions[i] += actualBytesTransferred;
-            bytesToTransfer -= actualBytesTransferred;
-          }
-          writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
-          bytesWrittenToMergedFile += partitionLengthInSpill;
-          partitionLengths[partition] += partitionLengthInSpill;
-        }
-      }
-      // Check the position after transferTo loop to see if it is in the right position and raise an
-      // exception if it is incorrect. The position will not be increased to the expected length
-      // after calling transferTo in kernel version 2.6.32. This issue is described at
-      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
-      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
-        throw new IOException(
-          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
-            "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
-            " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
-            "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
-            "to disable this NIO feature."
-        );
-      }
-      threwException = false;
-    } finally {
-      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
-      // throw exceptions during cleanup if threwException == false.
-      for (int i = 0; i < spills.length; i++) {
-        assert(spillInputChannelPositions[i] == spills[i].file.length());
-        Closeables.close(spillInputChannels[i], threwException);
-      }
-      Closeables.close(mergedFileOutputChannel, threwException);
-    }
-    return partitionLengths;
-  }
-
-  @Override
-  public Option<MapStatus> stop(boolean success) {
-    try {
-      // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
-      Map<String, Accumulator<Object>> internalAccumulators =
-        taskContext.internalMetricsToAccumulators();
-      if (internalAccumulators != null) {
-        internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
-          .add(getPeakMemoryUsedBytes());
-      }
-
-      if (stopping) {
-        return Option.apply(null);
-      } else {
-        stopping = true;
-        if (success) {
-          if (mapStatus == null) {
-            throw new IllegalStateException("Cannot call stop(true) without having called write()");
-          }
-          return Option.apply(mapStatus);
-        } else {
-          // The map task failed, so delete our output data.
-          shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
-          return Option.apply(null);
-        }
-      }
-    } finally {
-      if (sorter != null) {
-        // If sorter is non-null, then this implies that we called stop() in response to an error,
-        // so we need to clean up memory and spill files created by the sorter
-        sorter.cleanupResources();
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index c329983..704158b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -330,7 +330,7 @@ object SparkEnv extends Logging {
     val shortShuffleMgrNames = Map(
       "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
       "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
-      "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
+      "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
     val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
     val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
     val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 9df4e55..1105167 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -19,9 +19,53 @@ package org.apache.spark.shuffle.sort
 
 import java.util.concurrent.ConcurrentHashMap
 
-import org.apache.spark.{Logging, SparkConf, TaskContext, ShuffleDependency}
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle._
 
+/**
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * Sort-based shuffle has two different write paths for producing its map output files:
+ *
+ *  - Serialized sorting: used when all three of the following conditions hold:
+ *    1. The shuffle dependency specifies no aggregation or output ordering.
+ *    2. The shuffle serializer supports relocation of serialized values (this is currently
+ *       supported by KryoSerializer and Spark SQL's custom serializers).
+ *    3. The shuffle produces fewer than or equal to 16777216 output partitions.
+ *  - Deserialized sorting: used to handle all other cases.
+ *
+ * -----------------------
+ * Serialized sorting mode
+ * -----------------------
+ *
+ * In the serialized sorting mode, incoming records are serialized as soon as they are passed to the
+ * shuffle writer and are buffered in a serialized form during sorting. This write path implements
+ * several optimizations:
+ *
+ *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ *    consumption and GC overheads. This optimization requires the record serializer to have certain
+ *    properties to allow serialized records to be re-ordered without requiring deserialization.
+ *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ *  - It uses a specialized cache-efficient sorter ([[ShuffleExternalSorter]]) that sorts
+ *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ *    record in the sorting array, this fits more of the array into cache.
+ *
+ *  - The spill merging procedure operates on blocks of serialized records that belong to the same
+ *    partition and does not need to deserialize records during the merge.
+ *
+ *  - When the spill compression codec supports concatenation of compressed data, the spill merge
+ *    simply concatenates the serialized and compressed spill partitions to produce the final output
+ *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ *    and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on these optimizations, see SPARK-7081.
+ */
 private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
 
   if (!conf.getBoolean("spark.shuffle.spill", true)) {
@@ -30,8 +74,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
         " Shuffle will continue to spill to disk when necessary.")
   }
 
-  private val indexShuffleBlockResolver = new IndexShuffleBlockResolver(conf)
-  private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
+  /**
+   * A mapping from shuffle ids to the number of mappers producing output for those shuffles.
+   */
+  private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]()
+
+  override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf)
 
   /**
    * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
@@ -40,7 +88,22 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       shuffleId: Int,
       numMaps: Int,
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
-    new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    if (SortShuffleWriter.shouldBypassMergeSort(SparkEnv.get.conf, dependency)) {
+      // If there are at most spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+      // need map-side aggregation, then write numPartitions files directly and just concatenate
+      // them at the end. This avoids doing serialization and deserialization twice to merge
+      // together the spilled files, which would happen with the normal code path. The downside is
+      // having multiple files open at a time and thus more memory allocated to buffers.
+      new BypassMergeSortShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else if (SortShuffleManager.canUseSerializedShuffle(dependency)) {
+      // Otherwise, try to buffer map outputs in a serialized form, since this is more efficient:
+      new SerializedShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else {
+      // Otherwise, buffer map outputs in a deserialized form:
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
   }
 
   /**
@@ -52,38 +115,114 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       startPartition: Int,
       endPartition: Int,
       context: TaskContext): ShuffleReader[K, C] = {
-    // We currently use the same block store shuffle fetcher as the hash-based shuffle.
     new BlockStoreShuffleReader(
       handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
   }
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
-  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
-      : ShuffleWriter[K, V] = {
-    val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, V, _]]
-    shuffleMapNumber.putIfAbsent(baseShuffleHandle.shuffleId, baseShuffleHandle.numMaps)
-    new SortShuffleWriter(
-      shuffleBlockResolver, baseShuffleHandle, mapId, context)
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext): ShuffleWriter[K, V] = {
+    numMapsForShuffle.putIfAbsent(
+      handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
+    val env = SparkEnv.get
+    handle match {
+      case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] =>
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf)
+      case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
+        new BypassMergeSortShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          bypassMergeSortHandle,
+          mapId,
+          context,
+          env.conf)
+      case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
+        new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
+    }
   }
 
   /** Remove a shuffle's metadata from the ShuffleManager. */
   override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (shuffleMapNumber.containsKey(shuffleId)) {
-      val numMaps = shuffleMapNumber.remove(shuffleId)
-      (0 until numMaps).map{ mapId =>
+    Option(numMapsForShuffle.remove(shuffleId)).foreach { numMaps =>
+      (0 until numMaps).foreach { mapId =>
         shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
       }
     }
     true
   }
 
-  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
-    indexShuffleBlockResolver
-  }
-
   /** Shut down this ShuffleManager. */
   override def stop(): Unit = {
     shuffleBlockResolver.stop()
   }
 }
 
+
+private[spark] object SortShuffleManager extends Logging {
+
+  /**
+   * The maximum number of shuffle output partitions that SortShuffleManager supports when
+   * buffering map outputs in a serialized form. This is an extreme defensive programming measure,
+   * since it's extremely unlikely that a single shuffle produces over 16 million output partitions.
+   */
+  val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE =
+    PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+  /**
+   * Helper method for determining whether a shuffle should use an optimized serialized shuffle
+   * path or whether it should fall back to the original path that operates on deserialized objects.
+   */
+  def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
+    val shufId = dependency.shuffleId
+    val numPartitions = dependency.partitioner.numPartitions
+    val serializer = Serializer.getSerializer(dependency.serializer)
+    if (!serializer.supportsRelocationOfSerializedObjects) {
+      log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
+        s"${serializer.getClass.getName}, does not support object relocation")
+      false
+    } else if (dependency.aggregator.isDefined) {
+      log.debug(
+        s"Can't use serialized shuffle for shuffle $shufId because an aggregator is defined")
+      false
+    } else if (numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE) {
+      log.debug(s"Can't use serialized shuffle for shuffle $shufId because it has more than " +
+        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE partitions")
+      false
+    } else {
+      log.debug(s"Can use serialized shuffle for shuffle $shufId")
+      true
+    }
+  }
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * serialized shuffle.
+ */
+private[spark] class SerializedShuffleHandle[K, V](
+  shuffleId: Int,
+  numMaps: Int,
+  dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}
+
+/**
+ * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the
+ * bypass merge sort shuffle path.
+ */
+private[spark] class BypassMergeSortShuffleHandle[K, V](
+  shuffleId: Int,
+  numMaps: Int,
+  dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5865e76..bbd9c1a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
@@ -36,7 +35,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private val blockManager = SparkEnv.get.blockManager
 
-  private var sorter: SortShuffleFileWriter[K, V] = null
+  private var sorter: ExternalSorter[K, V, _] = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -54,15 +53,6 @@ private[spark] class SortShuffleWriter[K, V, C](
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-    } else if (SortShuffleWriter.shouldBypassMergeSort(
-        SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
-      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
-      // need local aggregation and sorting, write numPartitions files directly and just concatenate
-      // them at the end. This avoids doing serialization and deserialization twice to merge
-      // together the spilled files, which would happen with the normal code path. The downside is
-      // having multiple files open at a time and thus more memory allocated to buffers.
-      new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
-        writeMetrics, Serializer.getSerializer(dep.serializer))
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
@@ -111,12 +101,14 @@ private[spark] class SortShuffleWriter[K, V, C](
 }
 
 private[spark] object SortShuffleWriter {
-  def shouldBypassMergeSort(
-      conf: SparkConf,
-      numPartitions: Int,
-      aggregator: Option[Aggregator[_, _, _]],
-      keyOrdering: Option[Ordering[_]]): Boolean = {
-    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-    numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+  def shouldBypassMergeSort(conf: SparkConf, dep: ShuffleDependency[_, _, _]): Boolean = {
+    // We cannot bypass sorting if we need to do map-side aggregation.
+    if (dep.mapSideCombine) {
+      require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
+      false
+    } else {
+      val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+      dep.partitioner.numPartitions <= bypassMergeThreshold
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
deleted file mode 100644
index 75f22f6..0000000
--- a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe
-
-import java.util.Collections
-import java.util.concurrent.ConcurrentHashMap
-
-import org.apache.spark._
-import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle._
-import org.apache.spark.shuffle.sort.SortShuffleManager
-
-/**
- * Subclass of [[BaseShuffleHandle]], used to identify when we've chosen to use the new shuffle.
- */
-private[spark] class UnsafeShuffleHandle[K, V](
-    shuffleId: Int,
-    numMaps: Int,
-    dependency: ShuffleDependency[K, V, V])
-  extends BaseShuffleHandle(shuffleId, numMaps, dependency) {
-}
-
-private[spark] object UnsafeShuffleManager extends Logging {
-
-  /**
-   * The maximum number of shuffle output partitions that UnsafeShuffleManager supports.
-   */
-  val MAX_SHUFFLE_OUTPUT_PARTITIONS = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
-
-  /**
-   * Helper method for determining whether a shuffle should use the optimized unsafe shuffle
-   * path or whether it should fall back to the original sort-based shuffle.
-   */
-  def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
-    val shufId = dependency.shuffleId
-    val serializer = Serializer.getSerializer(dependency.serializer)
-    if (!serializer.supportsRelocationOfSerializedObjects) {
-      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
-        s"${serializer.getClass.getName}, does not support object relocation")
-      false
-    } else if (dependency.aggregator.isDefined) {
-      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
-      false
-    } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
-      log.debug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
-        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
-      false
-    } else {
-      log.debug(s"Can use UnsafeShuffle for shuffle $shufId")
-      true
-    }
-  }
-}
-
-/**
- * A shuffle implementation that uses directly-managed memory to implement several performance
- * optimizations for certain types of shuffles. In cases where the new performance optimizations
- * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
- * shuffles.
- *
- * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
- *
- *  - The shuffle dependency specifies no aggregation or output ordering.
- *  - The shuffle serializer supports relocation of serialized values (this is currently supported
- *    by KryoSerializer and Spark SQL's custom serializers).
- *  - The shuffle produces fewer than 16777216 output partitions.
- *  - No individual record is larger than 128 MB when serialized.
- *
- * In addition, extra spill-merging optimizations are automatically applied when the shuffle
- * compression codec supports concatenation of serialized streams. This is currently supported by
- * Spark's LZF serializer.
- *
- * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
- * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
- * written to a single map output file. Reducers fetch contiguous regions of this file in order to
- * read their portion of the map output. In cases where the map output data is too large to fit in
- * memory, sorted subsets of the output can are spilled to disk and those on-disk files are merged
- * to produce the final output file.
- *
- * UnsafeShuffleManager optimizes this process in several ways:
- *
- *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
- *    consumption and GC overheads. This optimization requires the record serializer to have certain
- *    properties to allow serialized records to be re-ordered without requiring deserialization.
- *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
- *
- *  - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
- *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
- *    record in the sorting array, this fits more of the array into cache.
- *
- *  - The spill merging procedure operates on blocks of serialized records that belong to the same
- *    partition and does not need to deserialize records during the merge.
- *
- *  - When the spill compression codec supports concatenation of compressed data, the spill merge
- *    simply concatenates the serialized and compressed spill partitions to produce the final output
- *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
- *    and avoids the need to allocate decompression or copying buffers during the merge.
- *
- * For more details on UnsafeShuffleManager's design, see SPARK-7081.
- */
-private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
-
-  if (!conf.getBoolean("spark.shuffle.spill", true)) {
-    logWarning(
-      "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
-      "manager; its optimized shuffles will continue to spill to disk when necessary.")
-  }
-
-  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
-  private[this] val shufflesThatFellBackToSortShuffle =
-    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
-  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
-
-  /**
-   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
-   */
-  override def registerShuffle[K, V, C](
-      shuffleId: Int,
-      numMaps: Int,
-      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
-    if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
-      new UnsafeShuffleHandle[K, V](
-        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
-    } else {
-      new BaseShuffleHandle(shuffleId, numMaps, dependency)
-    }
-  }
-
-  /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
-   * Called on executors by reduce tasks.
-   */
-  override def getReader[K, C](
-      handle: ShuffleHandle,
-      startPartition: Int,
-      endPartition: Int,
-      context: TaskContext): ShuffleReader[K, C] = {
-    sortShuffleManager.getReader(handle, startPartition, endPartition, context)
-  }
-
-  /** Get a writer for a given partition. Called on executors by map tasks. */
-  override def getWriter[K, V](
-      handle: ShuffleHandle,
-      mapId: Int,
-      context: TaskContext): ShuffleWriter[K, V] = {
-    handle match {
-      case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] =>
-        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
-        val env = SparkEnv.get
-        new UnsafeShuffleWriter(
-          env.blockManager,
-          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
-          context.taskMemoryManager(),
-          env.shuffleMemoryManager,
-          unsafeShuffleHandle,
-          mapId,
-          context,
-          env.conf)
-      case other =>
-        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
-        sortShuffleManager.getWriter(handle, mapId, context)
-    }
-  }
-
-  /** Remove a shuffle's metadata from the ShuffleManager. */
-  override def unregisterShuffle(shuffleId: Int): Boolean = {
-    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
-      sortShuffleManager.unregisterShuffle(shuffleId)
-    } else {
-      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
-        (0 until numMaps).foreach { mapId =>
-          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
-        }
-      }
-      true
-    }
-  }
-
-  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
-    sortShuffleManager.shuffleBlockResolver
-  }
-
-  /** Shut down this ShuffleManager. */
-  override def stop(): Unit = {
-    sortShuffleManager.stop()
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
deleted file mode 100644
index ae60f3b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
+++ /dev/null
@@ -1,146 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.OutputStream
-
-import scala.collection.mutable.ArrayBuffer
-
-/**
- * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
- * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
- * of memory and needing to copy the full contents. The disadvantage is that the contents don't
- * occupy a contiguous segment of memory.
- */
-private[spark] class ChainedBuffer(chunkSize: Int) {
-
-  private val chunkSizeLog2: Int = java.lang.Long.numberOfTrailingZeros(
-    java.lang.Long.highestOneBit(chunkSize))
-  assert((1 << chunkSizeLog2) == chunkSize,
-    s"ChainedBuffer chunk size $chunkSize must be a power of two")
-  private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
-  private var _size: Long = 0
-
-  /**
-   * Feed bytes from this buffer into a DiskBlockObjectWriter.
-   *
-   * @param pos Offset in the buffer to read from.
-   * @param os OutputStream to read into.
-   * @param len Number of bytes to read.
-   */
-  def read(pos: Long, os: OutputStream, len: Int): Unit = {
-    if (pos + len > _size) {
-      throw new IndexOutOfBoundsException(
-        s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
-    }
-    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
-    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
-    var written: Int = 0
-    while (written < len) {
-      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
-      os.write(chunks(chunkIndex), posInChunk, toRead)
-      written += toRead
-      chunkIndex += 1
-      posInChunk = 0
-    }
-  }
-
-  /**
-   * Read bytes from this buffer into a byte array.
-   *
-   * @param pos Offset in the buffer to read from.
-   * @param bytes Byte array to read into.
-   * @param offs Offset in the byte array to read to.
-   * @param len Number of bytes to read.
-   */
-  def read(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    if (pos + len > _size) {
-      throw new IndexOutOfBoundsException(
-        s"Read of $len bytes at position $pos would go past size of buffer")
-    }
-    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
-    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
-    var written: Int = 0
-    while (written < len) {
-      val toRead: Int = math.min(len - written, chunkSize - posInChunk)
-      System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
-      written += toRead
-      chunkIndex += 1
-      posInChunk = 0
-    }
-  }
-
-  /**
-   * Write bytes from a byte array into this buffer.
-   *
-   * @param pos Offset in the buffer to write to.
-   * @param bytes Byte array to write from.
-   * @param offs Offset in the byte array to write from.
-   * @param len Number of bytes to write.
-   */
-  def write(pos: Long, bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    if (pos > _size) {
-      throw new IndexOutOfBoundsException(
-        s"Write at position $pos starts after end of buffer ${_size}")
-    }
-    // Grow if needed
-    val endChunkIndex: Int = ((pos + len - 1) >> chunkSizeLog2).toInt
-    while (endChunkIndex >= chunks.length) {
-      chunks += new Array[Byte](chunkSize)
-    }
-
-    var chunkIndex: Int = (pos >> chunkSizeLog2).toInt
-    var posInChunk: Int = (pos - (chunkIndex.toLong << chunkSizeLog2)).toInt
-    var written: Int = 0
-    while (written < len) {
-      val toWrite: Int = math.min(len - written, chunkSize - posInChunk)
-      System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
-      written += toWrite
-      chunkIndex += 1
-      posInChunk = 0
-    }
-
-    _size = math.max(_size, pos + len)
-  }
-
-  /**
-   * Total size of buffer that can be written to without allocating additional memory.
-   */
-  def capacity: Long = chunks.size.toLong * chunkSize
-
-  /**
-   * Size of the logical buffer.
-   */
-  def size: Long = _size
-}
-
-/**
- * Output stream that writes to a ChainedBuffer.
- */
-private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
-  private var pos: Long = 0
-
-  override def write(b: Int): Unit = {
-    throw new UnsupportedOperationException()
-  }
-
-  override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
-    chainedBuffer.write(pos, bytes, offs, len)
-    pos += len
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 749be34..c48c453 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -29,7 +29,6 @@ import com.google.common.io.ByteStreams
 import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
 import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
 
 /**
@@ -69,8 +68,8 @@ import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter}
  * At a high level, this class works internally as follows:
  *
  * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
- *   we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
- *   don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ *   we want to combine by key, or a PartitionedPairBuffer if we don't.
+ *   Inside these buffers, we sort elements by partition ID and then possibly also by key.
  *   To avoid calling the partitioner multiple times with each key, we store the partition ID
  *   alongside each record.
  *
@@ -93,8 +92,7 @@ private[spark] class ExternalSorter[K, V, C](
     ordering: Option[Ordering[K]] = None,
     serializer: Option[Serializer] = None)
   extends Logging
-  with Spillable[WritablePartitionedPairCollection[K, C]]
-  with SortShuffleFileWriter[K, V] {
+  with Spillable[WritablePartitionedPairCollection[K, C]] {
 
   private val conf = SparkEnv.get.conf
 
@@ -104,13 +102,6 @@ private[spark] class ExternalSorter[K, V, C](
     if (shouldPartition) partitioner.get.getPartition(key) else 0
   }
 
-  // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
-  // As a sanity check, make sure that we're not handling a shuffle which should use that path.
-  if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
-    throw new IllegalArgumentException("ExternalSorter should not be used to handle "
-      + " a sort that the BypassMergeSortShuffleWriter should handle")
-  }
-
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer)
@@ -128,23 +119,11 @@ private[spark] class ExternalSorter[K, V, C](
   // grow internal data structures by growing + copying every time the number of objects doubles.
   private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
 
-  private val useSerializedPairBuffer =
-    ordering.isEmpty &&
-      conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
-      ser.supportsRelocationOfSerializedObjects
-  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
-  private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
-    if (useSerializedPairBuffer) {
-      new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
-    } else {
-      new PartitionedPairBuffer[K, C]
-    }
-  }
   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
   // store them in an array buffer.
   private var map = new PartitionedAppendOnlyMap[K, C]
-  private var buffer = newBuffer()
+  private var buffer = new PartitionedPairBuffer[K, C]
 
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
@@ -192,7 +171,7 @@ private[spark] class ExternalSorter[K, V, C](
    */
   private[spark] def numSpills: Int = spills.size
 
-  override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
+  def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -236,7 +215,7 @@ private[spark] class ExternalSorter[K, V, C](
     } else {
       estimatedSize = buffer.estimateSize()
       if (maybeSpill(buffer, estimatedSize)) {
-        buffer = newBuffer()
+        buffer = new PartitionedPairBuffer[K, C]
       }
     }
 
@@ -659,7 +638,7 @@ private[spark] class ExternalSorter[K, V, C](
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
-  override def writePartitionedFile(
+  def writePartitionedFile(
       blockId: BlockId,
       context: TaskContext,
       outputFile: File): Array[Long] = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[4/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations

Posted by jo...@apache.org.
[SPARK-10708] Consolidate sort shuffle implementations

There's a lot of duplication between SortShuffleManager and UnsafeShuffleManager. Given that these now provide the same set of functionality, now that UnsafeShuffleManager supports large records, I think that we should replace SortShuffleManager's serialized shuffle implementation with UnsafeShuffleManager's and should merge the two managers together.

Author: Josh Rosen <jo...@databricks.com>

Closes #8829 from JoshRosen/consolidate-sort-shuffle-implementations.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f6d06adf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f6d06adf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f6d06adf

Branch: refs/heads/master
Commit: f6d06adf05afa9c5386dc2396c94e7a98730289f
Parents: 94e2064
Author: Josh Rosen <jo...@databricks.com>
Authored: Thu Oct 22 09:46:30 2015 -0700
Committer: Josh Rosen <jo...@databricks.com>
Committed: Thu Oct 22 09:46:30 2015 -0700

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      | 106 +++-
 .../spark/shuffle/sort/PackedRecordPointer.java |  92 +++
 .../shuffle/sort/ShuffleExternalSorter.java     | 491 ++++++++++++++++
 .../shuffle/sort/ShuffleInMemorySorter.java     | 124 ++++
 .../shuffle/sort/ShuffleSortDataFormat.java     |  67 +++
 .../shuffle/sort/SortShuffleFileWriter.java     |  53 --
 .../apache/spark/shuffle/sort/SpillInfo.java    |  37 ++
 .../spark/shuffle/sort/UnsafeShuffleWriter.java | 489 ++++++++++++++++
 .../shuffle/unsafe/PackedRecordPointer.java     |  92 ---
 .../apache/spark/shuffle/unsafe/SpillInfo.java  |  37 --
 .../unsafe/UnsafeShuffleExternalSorter.java     | 479 ----------------
 .../unsafe/UnsafeShuffleInMemorySorter.java     | 124 ----
 .../unsafe/UnsafeShuffleSortDataFormat.java     |  67 ---
 .../shuffle/unsafe/UnsafeShuffleWriter.java     | 489 ----------------
 .../main/scala/org/apache/spark/SparkEnv.scala  |   2 +-
 .../spark/shuffle/sort/SortShuffleManager.scala | 175 +++++-
 .../spark/shuffle/sort/SortShuffleWriter.scala  |  28 +-
 .../shuffle/unsafe/UnsafeShuffleManager.scala   | 202 -------
 .../spark/util/collection/ChainedBuffer.scala   | 146 -----
 .../spark/util/collection/ExternalSorter.scala  |  35 +-
 .../PartitionedSerializedPairBuffer.scala       | 273 ---------
 .../shuffle/sort/PackedRecordPointerSuite.java  | 102 ++++
 .../sort/ShuffleInMemorySorterSuite.java        | 124 ++++
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  | 560 +++++++++++++++++++
 .../unsafe/PackedRecordPointerSuite.java        | 101 ----
 .../UnsafeShuffleInMemorySorterSuite.java       | 124 ----
 .../unsafe/UnsafeShuffleWriterSuite.java        | 560 -------------------
 .../org/apache/spark/SortShuffleSuite.scala     |  65 +++
 .../spark/scheduler/DAGSchedulerSuite.scala     |   6 +-
 .../BypassMergeSortShuffleWriterSuite.scala     |  64 ++-
 .../shuffle/sort/SortShuffleManagerSuite.scala  | 131 +++++
 .../shuffle/sort/SortShuffleWriterSuite.scala   |  45 --
 .../unsafe/UnsafeShuffleManagerSuite.scala      | 129 -----
 .../shuffle/unsafe/UnsafeShuffleSuite.scala     | 102 ----
 .../util/collection/ChainedBufferSuite.scala    | 144 -----
 .../PartitionedSerializedPairBufferSuite.scala  | 148 -----
 docs/configuration.md                           |   7 +-
 project/MimaExcludes.scala                      |   9 +-
 .../apache/spark/sql/execution/Exchange.scala   |  23 +-
 .../execution/UnsafeRowSerializerSuite.scala    |   9 +-
 40 files changed, 2600 insertions(+), 3461 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index f5d80bb..ee82d67 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -21,21 +21,30 @@ import java.io.File;
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
 import java.io.IOException;
+import javax.annotation.Nullable;
 
+import scala.None$;
+import scala.Option;
 import scala.Product2;
 import scala.Tuple2;
 import scala.collection.Iterator;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.Closeables;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.spark.Partitioner;
+import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
 import org.apache.spark.util.Utils;
 
@@ -62,7 +71,7 @@ import org.apache.spark.util.Utils;
  * <p>
  * There have been proposals to completely remove this code path; see SPARK-6026 for details.
  */
-final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
 
@@ -72,31 +81,52 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
   private final BlockManager blockManager;
   private final Partitioner partitioner;
   private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
   private final Serializer serializer;
+  private final IndexShuffleBlockResolver shuffleBlockResolver;
 
   /** Array of file writers, one for each partition */
   private DiskBlockObjectWriter[] partitionWriters;
+  @Nullable private MapStatus mapStatus;
+  private long[] partitionLengths;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc twice.
+   */
+  private boolean stopping = false;
 
   public BypassMergeSortShuffleWriter(
-      SparkConf conf,
       BlockManager blockManager,
-      Partitioner partitioner,
-      ShuffleWriteMetrics writeMetrics,
-      Serializer serializer) {
+      IndexShuffleBlockResolver shuffleBlockResolver,
+      BypassMergeSortShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext taskContext,
+      SparkConf conf) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
-    this.numPartitions = partitioner.numPartitions();
     this.blockManager = blockManager;
-    this.partitioner = partitioner;
-    this.writeMetrics = writeMetrics;
-    this.serializer = serializer;
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    this.mapId = mapId;
+    this.shuffleId = dep.shuffleId();
+    this.partitioner = dep.partitioner();
+    this.numPartitions = partitioner.numPartitions();
+    this.writeMetrics = new ShuffleWriteMetrics();
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.serializer = Serializer.getSerializer(dep.serializer());
+    this.shuffleBlockResolver = shuffleBlockResolver;
   }
 
   @Override
-  public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
     assert (partitionWriters == null);
     if (!records.hasNext()) {
+      partitionLengths = new long[numPartitions];
+      shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
       return;
     }
     final SerializerInstance serInstance = serializer.newInstance();
@@ -124,13 +154,24 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
     for (DiskBlockObjectWriter writer : partitionWriters) {
       writer.commitAndClose();
     }
+
+    partitionLengths =
+      writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
   }
 
-  @Override
-  public long[] writePartitionedFile(
-      BlockId blockId,
-      TaskContext context,
-      File outputFile) throws IOException {
+  @VisibleForTesting
+  long[] getPartitionLengths() {
+    return partitionLengths;
+  }
+
+  /**
+   * Concatenate all of the per-partition files into a single combined file.
+   *
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker).
+   */
+  private long[] writePartitionedFile(File outputFile) throws IOException {
     // Track location of the partition starts in the output file
     final long[] lengths = new long[numPartitions];
     if (partitionWriters == null) {
@@ -165,18 +206,33 @@ final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<
   }
 
   @Override
-  public void stop() throws IOException {
-    if (partitionWriters != null) {
-      try {
-        for (DiskBlockObjectWriter writer : partitionWriters) {
-          // This method explicitly does _not_ throw exceptions:
-          File file = writer.revertPartialWritesAndClose();
-          if (!file.delete()) {
-            logger.error("Error while deleting file {}", file.getAbsolutePath());
+  public Option<MapStatus> stop(boolean success) {
+    if (stopping) {
+      return None$.empty();
+    } else {
+      stopping = true;
+      if (success) {
+        if (mapStatus == null) {
+          throw new IllegalStateException("Cannot call stop(true) without having called write()");
+        }
+        return Option.apply(mapStatus);
+      } else {
+        // The map task failed, so delete our output data.
+        if (partitionWriters != null) {
+          try {
+            for (DiskBlockObjectWriter writer : partitionWriters) {
+              // This method explicitly does _not_ throw exceptions:
+              File file = writer.revertPartialWritesAndClose();
+              if (!file.delete()) {
+                logger.error("Error while deleting file {}", file.getAbsolutePath());
+              }
+            }
+          } finally {
+            partitionWriters = null;
           }
         }
-      } finally {
-        partitionWriters = null;
+        shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+        return None$.empty();
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
new file mode 100644
index 0000000..c117119
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
+ */
+final class PackedRecordPointer {
+
+  static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27;  // 128 megabytes
+
+  /**
+   * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+   */
+  static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1;  // 16777215
+
+  /** Bit mask for the lower 40 bits of a long. */
+  private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+  /** Bit mask for the upper 24 bits of a long */
+  private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+  /** Bit mask for the lower 27 bits of a long. */
+  private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+  /** Bit mask for the lower 51 bits of a long. */
+  private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+  /** Bit mask for the upper 13 bits of a long */
+  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+  /**
+   * Pack a record address and partition id into a single word.
+   *
+   * @param recordPointer a record pointer encoded by TaskMemoryManager.
+   * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+   * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+   */
+  public static long packPointer(long recordPointer, int partitionId) {
+    assert (partitionId <= MAXIMUM_PARTITION_ID);
+    // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+    // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
+    final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+    final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+    return (((long) partitionId) << 40) | compressedAddress;
+  }
+
+  private long packedRecordPointer;
+
+  public void set(long packedRecordPointer) {
+    this.packedRecordPointer = packedRecordPointer;
+  }
+
+  public int getPartitionId() {
+    return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+  }
+
+  public long getRecordPointer() {
+    final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+    final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+    return pageNumber | offsetInPage;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
new file mode 100644
index 0000000..85fdaa8
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -0,0 +1,491 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import javax.annotation.Nullable;
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.DummySerializerInstance;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.TempShuffleBlockId;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.array.ByteArrayMethods;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ * <p>
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link ShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ * <p>
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
+ */
+final class ShuffleExternalSorter {
+
+  private final Logger logger = LoggerFactory.getLogger(ShuffleExternalSorter.class);
+
+  @VisibleForTesting
+  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+
+  private final int initialSize;
+  private final int numPartitions;
+  private final int pageSizeBytes;
+  @VisibleForTesting
+  final int maxRecordSizeBytes;
+  private final TaskMemoryManager taskMemoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final BlockManager blockManager;
+  private final TaskContext taskContext;
+  private final ShuffleWriteMetrics writeMetrics;
+  private long numRecordsInsertedSinceLastSpill = 0;
+
+  /** Force this sorter to spill when there are this many elements in memory. For testing only */
+  private final long numElementsForSpillThreshold;
+
+  /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+  private final int fileBufferSizeBytes;
+
+  /**
+   * Memory pages that hold the records being sorted. The pages in this list are freed when
+   * spilling, although in principle we could recycle these pages across spills (on the other hand,
+   * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+   * itself).
+   */
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+  private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+  /** Peak memory used by this sorter so far, in bytes. **/
+  private long peakMemoryUsedBytes;
+
+  // These variables are reset after spilling:
+  @Nullable private ShuffleInMemorySorter inMemSorter;
+  @Nullable private MemoryBlock currentPage = null;
+  private long currentPagePosition = -1;
+  private long freeSpaceInCurrentPage = 0;
+
+  public ShuffleExternalSorter(
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf,
+      ShuffleWriteMetrics writeMetrics) throws IOException {
+    this.taskMemoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.blockManager = blockManager;
+    this.taskContext = taskContext;
+    this.initialSize = initialSize;
+    this.peakMemoryUsedBytes = initialSize;
+    this.numPartitions = numPartitions;
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.numElementsForSpillThreshold =
+      conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
+    this.pageSizeBytes = (int) Math.min(
+      PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, shuffleMemoryManager.pageSizeBytes());
+    this.maxRecordSizeBytes = pageSizeBytes - 4;
+    this.writeMetrics = writeMetrics;
+    initializeForWriting();
+
+    // preserve first page to ensure that we have at least one page to work with. Otherwise,
+    // other operators in the same task may starve this sorter (SPARK-9709).
+    acquireNewPageIfNecessary(pageSizeBytes);
+  }
+
+  /**
+   * Allocates new sort data structures. Called when creating the sorter and after each spill.
+   */
+  private void initializeForWriting() throws IOException {
+    // TODO: move this sizing calculation logic into a static method of sorter:
+    final long memoryRequested = initialSize * 8L;
+    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+    if (memoryAcquired != memoryRequested) {
+      shuffleMemoryManager.release(memoryAcquired);
+      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+    }
+
+    this.inMemSorter = new ShuffleInMemorySorter(initialSize);
+    numRecordsInsertedSinceLastSpill = 0;
+  }
+
+  /**
+   * Sorts the in-memory records and writes the sorted records to an on-disk file.
+   * This method does not free the sort data structures.
+   *
+   * @param isLastFile if true, this indicates that we're writing the final output file and that the
+   *                   bytes written should be counted towards shuffle write metrics rather than
+   *                   shuffle spill metrics.
+   */
+  private void writeSortedFile(boolean isLastFile) throws IOException {
+
+    final ShuffleWriteMetrics writeMetricsToUse;
+
+    if (isLastFile) {
+      // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+      writeMetricsToUse = writeMetrics;
+    } else {
+      // We're spilling, so bytes written should be counted towards spill rather than write.
+      // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+      // them towards shuffle bytes written.
+      writeMetricsToUse = new ShuffleWriteMetrics();
+    }
+
+    // This call performs the actual sort.
+    final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
+      inMemSorter.getSortedIterator();
+
+    // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+    // after SPARK-5581 is fixed.
+    DiskBlockObjectWriter writer;
+
+    // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+    // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+    // data through a byte array. This array does not need to be large enough to hold a single
+    // record, because larger records are copied through it in multiple chunks.
+    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+    // Because this output will be read during shuffle, its compression codec must be controlled by
+    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+    // createTempShuffleBlock here; see SPARK-3426 for more details.
+    final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+      blockManager.diskBlockManager().createTempShuffleBlock();
+    final File file = spilledFileInfo._2();
+    final TempShuffleBlockId blockId = spilledFileInfo._1();
+    final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+    // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+    // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+    // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+    // around this, we pass a dummy no-op serializer.
+    final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+    writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+    int currentPartition = -1;
+    while (sortedRecords.hasNext()) {
+      sortedRecords.loadNext();
+      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+      assert (partition >= currentPartition);
+      if (partition != currentPartition) {
+        // Switch to the new partition
+        if (currentPartition != -1) {
+          writer.commitAndClose();
+          spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        }
+        currentPartition = partition;
+        writer =
+          blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+      }
+
+      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+      final Object recordPage = taskMemoryManager.getPage(recordPointer);
+      final long recordOffsetInPage = taskMemoryManager.getOffsetInPage(recordPointer);
+      int dataRemaining = Platform.getInt(recordPage, recordOffsetInPage);
+      long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+      while (dataRemaining > 0) {
+        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+        Platform.copyMemory(
+          recordPage, recordReadPosition, writeBuffer, Platform.BYTE_ARRAY_OFFSET, toTransfer);
+        writer.write(writeBuffer, 0, toTransfer);
+        recordReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
+      writer.recordWritten();
+    }
+
+    if (writer != null) {
+      writer.commitAndClose();
+      // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+      // then the file might be empty. Note that it might be better to avoid calling
+      // writeSortedFile() in that case.
+      if (currentPartition != -1) {
+        spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        spills.add(spillInfo);
+      }
+    }
+
+    if (!isLastFile) {  // i.e. this is a spill file
+      // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+      // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+      // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+      // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+      // counter at a higher-level, then the in-progress metrics for records written and bytes
+      // written would get out of sync.
+      //
+      // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+      // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+      // metrics to the true write metrics here. The reason for performing this copying is so that
+      // we can avoid reporting spilled bytes as shuffle write bytes.
+      //
+      // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+      // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+      writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+    }
+  }
+
+  /**
+   * Sort and spill the current records in response to memory pressure.
+   */
+  @VisibleForTesting
+  void spill() throws IOException {
+    // No padding on the unit: the format string already puts a space between the arguments.
+    logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+      Thread.currentThread().getId(),
+      Utils.bytesToString(getMemoryUsage()),
+      spills.size(),
+      spills.size() == 1 ? "time" : "times");
+
+    writeSortedFile(false);
+    final long inMemSorterMemoryUsage = inMemSorter.getMemoryUsage();
+    inMemSorter = null;
+    shuffleMemoryManager.release(inMemSorterMemoryUsage);
+    final long spillSize = freeMemory();
+    taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+    initializeForWriting();
+  }
+
+  private long getMemoryUsage() { // bytes held by the pointer array plus all allocated pages
+    long usedBytes = (inMemSorter != null) ? inMemSorter.getMemoryUsage() : 0;
+    for (MemoryBlock block : allocatedPages) {
+      usedBytes += block.size();
+    }
+    return usedBytes;
+  }
+
+  private void updatePeakMemoryUsed() {
+    long mem = getMemoryUsage();
+    if (mem > peakMemoryUsedBytes) {
+      peakMemoryUsedBytes = mem;
+    }
+  }
+
+  /**
+   * Refresh the recorded peak from the current usage, then return the peak so far, in bytes.
+   */
+  long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
+  }
+
+  private long freeMemory() {
+    updatePeakMemoryUsed(); // capture the high-water mark before anything is handed back
+    long releasedBytes = 0L;
+    for (MemoryBlock page : allocatedPages) {
+      taskMemoryManager.freePage(page);
+      shuffleMemoryManager.release(page.size());
+      releasedBytes += page.size();
+    }
+    allocatedPages.clear();
+    freeSpaceInCurrentPage = 0; // no page is current any more, so reset the write cursor state
+    currentPagePosition = -1;
+    currentPage = null;
+    return releasedBytes;
+  }
+
+  /**
+   * Free all memory and delete all spill files (best effort); called by shuffle error handling.
+   */
+  public void cleanupResources() {
+    freeMemory(); // hands back every allocated data page
+    for (SpillInfo spill : spills) {
+      if (spill.file.exists() && !spill.file.delete()) { // a failed delete is logged, not thrown
+        logger.error("Unable to delete spill file {}", spill.file.getPath());
+      }
+    }
+    if (inMemSorter != null) {
+      shuffleMemoryManager.release(inMemSorter.getMemoryUsage()); // pointer array reservation
+      inMemSorter = null;
+    }
+  }
+
+  /**
+   * Ensure the sort pointer array can take one more record. When it is full, try to reserve
+   * memory for a doubled array and expand it; if the reservation falls short, spill to disk.
+   */
+  private void growPointerArrayIfNecessary() throws IOException {
+    assert(inMemSorter != null);
+    if (inMemSorter.hasSpaceForAnotherRecord()) {
+      return;
+    }
+    logger.debug("Attempting to expand sort pointer array");
+    final long currentArrayBytes = inMemSorter.getMemoryUsage();
+    final long expandedArrayBytes = currentArrayBytes * 2;
+    final long grantedBytes = shuffleMemoryManager.tryToAcquire(expandedArrayBytes);
+    if (grantedBytes >= expandedArrayBytes) {
+      inMemSorter.expandPointerArray();
+      shuffleMemoryManager.release(currentArrayBytes);
+    } else {
+      shuffleMemoryManager.release(grantedBytes);
+      spill();
+    }
+  }
+  
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+   * obtained.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size. This must be less than or equal to the page size (records
+   *                      that exceed the page size are handled via a different code path which uses
+   *                      special overflow pages).
+   */
+  private void acquireNewPageIfNecessary(int requiredSpace) throws IOException {
+    growPointerArrayIfNecessary();
+    if (requiredSpace > freeSpaceInCurrentPage) {
+      logger.trace("Required space {} is greater than free space in current page ({})",
+        requiredSpace, freeSpaceInCurrentPage);
+      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+      // without using the free space at the end of the current page. We should also do this for
+      // BytesToBytesMap.
+      if (requiredSpace > pageSizeBytes) {
+        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+          pageSizeBytes + ")");
+      } else {
+        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+        if (memoryAcquired < pageSizeBytes) {
+          // Give back the partial grant, spill to free memory, then retry the full request once.
+          shuffleMemoryManager.release(memoryAcquired);
+          spill();
+          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+          if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+            throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
+          }
+        }
+        currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+        currentPagePosition = currentPage.getBaseOffset();
+        freeSpaceInCurrentPage = pageSizeBytes;
+        allocatedPages.add(currentPage);
+      }
+    }
+  }
+
+  /**
+   * Write a record to the shuffle sorter: copy it, length-prefixed, into a page and index it.
+   */
+  public void insertRecord(
+      Object recordBaseObject,
+      long recordBaseOffset,
+      int lengthInBytes,
+      int partitionId) throws IOException {
+
+    if (numRecordsInsertedSinceLastSpill > numElementsForSpillThreshold) {
+      spill(); // too many records buffered since the last spill, independent of memory pressure
+    }
+
+    growPointerArrayIfNecessary();
+    // Need 4 bytes to store the record length.
+    final int totalSpaceRequired = lengthInBytes + 4;
+
+    // --- Figure out where to insert the new record ----------------------------------------------
+
+    final MemoryBlock dataPage;
+    long dataPagePosition;
+    boolean useOverflowPage = totalSpaceRequired > pageSizeBytes;
+    if (useOverflowPage) {
+      long overflowPageSize = ByteArrayMethods.roundNumberOfBytesToNearestWord(totalSpaceRequired);
+      // The record is larger than the page size, so allocate a special overflow page just to hold
+      // that record.
+      final long memoryGranted = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+      if (memoryGranted != overflowPageSize) {
+        shuffleMemoryManager.release(memoryGranted); // return the partial grant before spilling
+        spill();
+        final long memoryGrantedAfterSpill = shuffleMemoryManager.tryToAcquire(overflowPageSize);
+        if (memoryGrantedAfterSpill != overflowPageSize) {
+          shuffleMemoryManager.release(memoryGrantedAfterSpill);
+          throw new IOException("Unable to acquire " + overflowPageSize + " bytes of memory");
+        }
+      }
+      MemoryBlock overflowPage = taskMemoryManager.allocatePage(overflowPageSize);
+      allocatedPages.add(overflowPage);
+      dataPage = overflowPage;
+      dataPagePosition = overflowPage.getBaseOffset();
+    } else {
+      // The record is small enough to fit in a regular data page, but the current page might not
+      // have enough space to hold it (or no pages have been allocated yet).
+      acquireNewPageIfNecessary(totalSpaceRequired);
+      dataPage = currentPage;
+      dataPagePosition = currentPagePosition;
+      // Update bookkeeping information
+      freeSpaceInCurrentPage -= totalSpaceRequired;
+      currentPagePosition += totalSpaceRequired;
+    }
+    final Object dataPageBaseObject = dataPage.getBaseObject();
+
+    final long recordAddress =
+      taskMemoryManager.encodePageNumberAndOffset(dataPage, dataPagePosition);
+    Platform.putInt(dataPageBaseObject, dataPagePosition, lengthInBytes);
+    dataPagePosition += 4; // advance past the length prefix written above
+    Platform.copyMemory(
+      recordBaseObject, recordBaseOffset, dataPageBaseObject, dataPagePosition, lengthInBytes);
+    assert(inMemSorter != null);
+    inMemSorter.insertRecord(recordAddress, partitionId);
+    numRecordsInsertedSinceLastSpill += 1;
+  }
+
+  /**
+   * Close the sorter, causing any buffered data to be sorted and written out to disk.
+   *
+   * @return metadata for the spill files written by this sorter. If no records were ever inserted
+   *         into this sorter, then this will return an empty array.
+   * @throws IOException if writing the final file fails; resources are cleaned up before rethrow
+   */
+  public SpillInfo[] closeAndGetSpills() throws IOException {
+    try {
+      if (inMemSorter != null) {
+        // Do not count the final file towards the spill count.
+        writeSortedFile(true);
+        freeMemory();
+      }
+      return spills.toArray(new SpillInfo[spills.size()]);
+    } catch (IOException e) {
+      cleanupResources();
+      throw e;
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
new file mode 100644
index 0000000..a8dee6c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
+final class ShuffleInMemorySorter {
+
+  private final Sorter<PackedRecordPointer, long[]> sorter;
+  private static final class SortComparator implements Comparator<PackedRecordPointer> {
+    @Override
+    public int compare(PackedRecordPointer left, PackedRecordPointer right) {
+      return left.getPartitionId() - right.getPartitionId();
+    }
+  }
+  private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+  /**
+   * Packed entries, one {@code long} per inserted record, each produced by
+   * {@link PackedRecordPointer#packPointer}. Sorting rearranges these entries only; the records
+   * they refer to are never touched by this class.
+   */
+  private long[] pointerArray;
+
+  /**
+   * Index of the next free slot in the pointer array, i.e. the number of records inserted.
+   */
+  private int pointerArrayInsertPosition = 0;
+
+  public ShuffleInMemorySorter(int initialSize) {
+    assert (initialSize > 0);
+    this.sorter = new Sorter<PackedRecordPointer, long[]>(ShuffleSortDataFormat.INSTANCE);
+    this.pointerArray = new long[initialSize];
+  }
+
+  public void expandPointerArray() {
+    // Doubling can wrap around to a non-positive int; fall back to the largest int in that case.
+    final int doubledLength = pointerArray.length * 2;
+    final int expandedLength = (doubledLength > 0) ? doubledLength : Integer.MAX_VALUE;
+    final long[] expanded = new long[expandedLength];
+    System.arraycopy(pointerArray, 0, expanded, 0, pointerArray.length);
+    pointerArray = expanded;
+  }
+
+  public boolean hasSpaceForAnotherRecord() {
+    return pointerArray.length > pointerArrayInsertPosition + 1;
+  }
+
+  public long getMemoryUsage() {
+    return 8L * pointerArray.length;
+  }
+
+  /**
+   * Append one record to the set of entries to be sorted.
+   *
+   * @param recordPointer address of the record as encoded by the task memory manager. Because the
+   *                      sorter compresses pointers, the address must fall within the first
+   *                      {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of its data
+   *                      page.
+   * @param partitionId the record's partition id; must not exceed
+   *                    {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+   */
+  public void insertRecord(long recordPointer, int partitionId) {
+    if (!hasSpaceForAnotherRecord()) {
+      if (pointerArray.length == Integer.MAX_VALUE) {
+        throw new IllegalStateException("Sort pointer array has reached maximum size");
+      }
+      expandPointerArray();
+    }
+    final long packed = PackedRecordPointer.packPointer(recordPointer, partitionId);
+    pointerArray[pointerArrayInsertPosition] = packed;
+    pointerArrayInsertPosition++;
+  }
+
+  /**
+   * Iterator-like cursor over the pointer array; avoids Java's Iterator to facilitate inlining.
+   */
+  public static final class ShuffleSorterIterator {
+
+    private final long[] entries;
+    private final int limit;
+    final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+    private int cursor = 0;
+
+    public ShuffleSorterIterator(int numRecords, long[] pointerArray) {
+      this.entries = pointerArray;
+      this.limit = numRecords;
+    }
+
+    public boolean hasNext() {
+      return cursor < limit;
+    }
+
+    public void loadNext() {
+      packedRecordPointer.set(entries[cursor]);
+      cursor++;
+    }
+  }
+
+  /** Sort the inserted entries with the partition-id comparator and return a cursor over them. */
+  public ShuffleSorterIterator getSortedIterator() {
+    // Only the filled prefix of the array is sorted and exposed to the cursor.
+    final int recordCount = pointerArrayInsertPosition;
+    sorter.sort(pointerArray, 0, recordCount, SORT_COMPARATOR);
+    return new ShuffleSorterIterator(recordCount, pointerArray);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
new file mode 100644
index 0000000..8a1e5ae
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+final class ShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+  public static final ShuffleSortDataFormat INSTANCE = new ShuffleSortDataFormat();
+
+  private ShuffleSortDataFormat() { } // not instantiable from outside; use INSTANCE
+
+  @Override
+  public PackedRecordPointer getKey(long[] data, int pos) {
+    // Keys are meant to be read through the key-reusing overload, so this one is unsupported.
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public PackedRecordPointer newKey() {
+    return new PackedRecordPointer(); // fresh mutable holder to pass as `reuse` below
+  }
+
+  @Override
+  public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+    reuse.set(data[pos]); // overwrite the holder with the packed value at pos; no allocation
+    return reuse;
+  }
+
+  @Override
+  public void swap(long[] data, int pos0, int pos1) {
+    final long temp = data[pos0];
+    data[pos0] = data[pos1];
+    data[pos1] = temp;
+  }
+
+  @Override
+  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+    dst[dstPos] = src[srcPos];
+  }
+
+  @Override
+  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+    System.arraycopy(src, srcPos, dst, dstPos, length);
+  }
+
+  @Override
+  public long[] allocate(int length) {
+    return new long[length]; // buffer of the same element type as the data being sorted
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
deleted file mode 100644
index 656ea04..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.sort;
-
-import java.io.File;
-import java.io.IOException;
-
-import scala.Product2;
-import scala.collection.Iterator;
-
-import org.apache.spark.annotation.Private;
-import org.apache.spark.TaskContext;
-import org.apache.spark.storage.BlockId;
-
-/**
- * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
- */
-@Private
-public interface SortShuffleFileWriter<K, V> {
-
-  void insertAll(Iterator<Product2<K, V>> records) throws IOException;
-
-  /**
-   * Write all the data added into this shuffle sorter into a file in the disk store. This is
-   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
-   * binary files if we decided to avoid merge-sorting.
-   *
-   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
-   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
-   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
-   */
-  long[] writePartitionedFile(
-      BlockId blockId,
-      TaskContext context,
-      File outputFile) throws IOException;
-
-  void stop() throws IOException;
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
new file mode 100644
index 0000000..df9f7b7
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link ShuffleExternalSorter}.
+ */
+final class SpillInfo {
+  final long[] partitionLengths; // one slot per partition, zero-initialized; filled by the writer
+  final File file; // file associated with this block of data
+  final TempShuffleBlockId blockId; // temporary shuffle block id associated with the file
+
+  public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+    this.partitionLengths = new long[numPartitions];
+    this.file = file;
+    this.blockId = blockId;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
new file mode 100644
index 0000000..e8f050c
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -0,0 +1,489 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import javax.annotation.Nullable;
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConverters;
+import scala.collection.immutable.Map;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+  @VisibleForTesting
+  static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+  private final BlockManager blockManager;
+  private final IndexShuffleBlockResolver shuffleBlockResolver;
+  private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final SerializerInstance serializer;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
+  private final TaskContext taskContext;
+  private final SparkConf sparkConf;
+  private final boolean transferToEnabled;
+
+  @Nullable private MapStatus mapStatus;
+  @Nullable private ShuffleExternalSorter sorter;
+  private long peakMemoryUsedBytes = 0;
+
+  /** ByteArrayOutputStream subclass that exposes `buf` directly, avoiding toByteArray()'s copy. */
+  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+    public MyByteArrayOutputStream(int size) { super(size); }
+    public byte[] getBuf() { return buf; } // backing array; only the first size() bytes are valid
+  }
+
+  private MyByteArrayOutputStream serBuffer;
+  private SerializationStream serOutputStream;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc twice.
+   */
+  private boolean stopping = false;
+
+  public UnsafeShuffleWriter(
+      BlockManager blockManager,
+      IndexShuffleBlockResolver shuffleBlockResolver,
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      SerializedShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext taskContext,
+      SparkConf sparkConf) throws IOException {
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    final int maxPartitions =
+      SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE();
+    if (dep.partitioner().numPartitions() > maxPartitions) {
+      throw new IllegalArgumentException("UnsafeShuffleWriter can only be used for shuffles " +
+        "with at most " + maxPartitions + " reduce partitions");
+    }
+    this.blockManager = blockManager;
+    this.shuffleBlockResolver = shuffleBlockResolver;
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.mapId = mapId;
+    this.shuffleId = dep.shuffleId();
+    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+    this.partitioner = dep.partitioner();
+    this.writeMetrics = new ShuffleWriteMetrics();
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.taskContext = taskContext;
+    this.sparkConf = sparkConf;
+    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    open(); // builds the sorter and serialization buffer from the fields assigned above
+  }
+
+  @VisibleForTesting
+  public int maxRecordSizeBytes() {
+    assert(sorter != null); // the sorter is cleared by closeAndWriteOutput()
+    return sorter.maxRecordSizeBytes;
+  }
+
+  private void updatePeakMemoryUsed() {
+    // sorter can be null if this writer is closed
+    if (sorter != null) {
+      long mem = sorter.getPeakMemoryUsedBytes();
+      if (mem > peakMemoryUsedBytes) {
+        peakMemoryUsedBytes = mem;
+      }
+    }
+  }
+
+  /**
+   * Refresh the peak from the sorter (if one is still open) and return it, in bytes.
+   */
+  public long getPeakMemoryUsedBytes() {
+    updatePeakMemoryUsed();
+    return peakMemoryUsedBytes;
+  }
+
+  /**
+   * Test-only convenience overload: adapts a Java iterator to Scala and delegates to write.
+   */
+  @VisibleForTesting
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
+    write(JavaConverters.asScalaIteratorConverter(records).asScala());
+  }
+
+  @Override
+  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    // Track success with a flag rather than a standard try/catch/re-throw, so that the
+    // finally block can tell whether any throwable (not only an exception) escaped the
+    // insert/close sequence.
+    boolean success = false;
+    try {
+      while (records.hasNext()) {
+        insertRecordIntoSorter(records.next());
+      }
+      closeAndWriteOutput();
+      success = true;
+    } finally {
+      if (sorter != null) { // non-null only if closeAndWriteOutput() did not clear it
+        try {
+          sorter.cleanupResources();
+        } catch (Exception e) {
+          // Only throw this error if we won't be masking another
+          // error.
+          if (success) {
+            throw e;
+          } else {
+            logger.error("In addition to a failure during writing, we failed during " +
+                         "cleanup.", e);
+          }
+        }
+      }
+    }
+  }
+
+  private void open() throws IOException {
+    assert (sorter == null); // must not be called while a sorter is already open
+    sorter = new ShuffleExternalSorter(
+      memoryManager,
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
+      INITIAL_SORT_BUFFER_SIZE,
+      partitioner.numPartitions(),
+      sparkConf,
+      writeMetrics);
+    serBuffer = new MyByteArrayOutputStream(1024 * 1024); // initial capacity of 1 MiB
+    serOutputStream = serializer.serializeStream(serBuffer); // records are serialized into serBuffer
+  }
+
+  @VisibleForTesting
+  void closeAndWriteOutput() throws IOException {
+    assert(sorter != null);
+    updatePeakMemoryUsed(); // sample the sorter one last time before the reference is dropped
+    serBuffer = null;
+    serOutputStream = null;
+    final SpillInfo[] spills = sorter.closeAndGetSpills();
+    sorter = null;
+    final long[] partitionLengths;
+    try {
+      partitionLengths = mergeSpills(spills);
+    } finally {
+      for (SpillInfo spill : spills) { // remove spill files whether or not the merge succeeded
+        if (spill.file.exists() && ! spill.file.delete()) {
+          logger.error("Error while deleting spill file {}", spill.file.getPath());
+        }
+      }
+    }
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+  }
+
+  @VisibleForTesting
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    assert(sorter != null);
+    final K recordKey = record._1();
+    final int targetPartition = partitioner.getPartition(recordKey);
+    // Serialize the pair into the reusable buffer, discarding the previous record's bytes first.
+    serBuffer.reset();
+    serOutputStream.writeKey(recordKey, OBJECT_CLASS_TAG);
+    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+    serOutputStream.flush();
+    final int numBytes = serBuffer.size();
+    assert (numBytes > 0);
+    // Hand the raw backing array to the sorter together with the number of valid bytes in it.
+    sorter.insertRecord(
+      serBuffer.getBuf(), Platform.BYTE_ARRAY_OFFSET, numBytes, targetPartition);
+  }
+
+  @VisibleForTesting
+  void forceSorterToSpill() throws IOException {
+    assert (sorter != null);
+    sorter.spill(); // test hook: trigger a spill directly instead of waiting for memory pressure
+  }
+
+  /**
+   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+   * number of spills and the IO compression codec.
+   *
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+    final boolean fastMergeEnabled =
+      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+    final boolean fastMergeIsSupported =
+      !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+    try {
+      if (spills.length == 0) {
+        new FileOutputStream(outputFile).close(); // Create an empty file
+        return new long[partitioner.numPartitions()];
+      } else if (spills.length == 1) {
+        // Here, we don't need to perform any metrics updates because the bytes written to this
+        // output file would have already been counted as shuffle bytes written.
+        Files.move(spills[0].file, outputFile);
+        return spills[0].partitionLengths;
+      } else {
+        final long[] partitionLengths;
+        // There are multiple spills to merge, so none of these spill files' lengths were counted
+        // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+        // then the final output file's size won't necessarily be equal to the sum of the spill
+        // files' sizes. To guard against this case, we look at the output file's actual size when
+        // computing shuffle bytes written.
+        //
+        // We allow the individual merge methods to report their own IO times since different merge
+        // strategies use different IO techniques.  We count IO during merge towards the
+        // shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
+        // branch in ExternalSorter.
+        if (fastMergeEnabled && fastMergeIsSupported) {
+          // Compression is disabled or we are using an IO compression codec that supports
+          // decompression of concatenated compressed streams, so we can perform a fast spill merge
+          // that doesn't need to interpret the spilled bytes.
+          if (transferToEnabled) {
+            logger.debug("Using transferTo-based fast merge");
+            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+          } else {
+            logger.debug("Using fileStream-based fast merge");
+            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);
+          }
+        } else {
+          logger.debug("Using slow merge");
+          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+        }
+        // When closing a ShuffleExternalSorter that has already spilled once but also has
+        // in-memory records, we write out the in-memory records to a file but do not count that
+        // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+        // to be counted as shuffle write, but this will lead to double-counting of the final
+        // SpillInfo's bytes.
+        writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+        writeMetrics.incShuffleBytesWritten(outputFile.length());
+        return partitionLengths;
+      }
+    } catch (IOException e) {
+      if (outputFile.exists() && !outputFile.delete()) {
+        logger.error("Unable to delete output file {}", outputFile.getPath());
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+   * cases where the IO compression codec does not support concatenation of compressed data, or in
+   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+   * kernel bugs.
+   *
+   * @param spills the spills to merge.
+   * @param outputFile the file to write the merged data to.
+   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpillsWithFileStream(
+      SpillInfo[] spills,
+      File outputFile,
+      @Nullable CompressionCodec compressionCodec) throws IOException {
+    assert (spills.length >= 2);
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
+    final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+    OutputStream mergedFileOutputStream = null;
+
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputStreams[i] = new FileInputStream(spills[i].file);
+      }
+      for (int partition = 0; partition < numPartitions; partition++) {
+        final long initialFileLength = outputFile.length();
+        mergedFileOutputStream =
+          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        if (compressionCodec != null) {
+          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+        }
+
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          if (partitionLengthInSpill > 0) {
+            InputStream partitionInputStream =
+              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+            if (compressionCodec != null) {
+              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            }
+            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+          }
+        }
+        mergedFileOutputStream.flush();
+        mergedFileOutputStream.close();
+        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+      }
+      threwException = false;
+    } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
+      for (InputStream stream : spillInputStreams) {
+        Closeables.close(stream, threwException);
+      }
+      Closeables.close(mergedFileOutputStream, threwException);
+    }
+    return partitionLengths;
+  }
+
+  /**
+   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+   * This is only safe when the IO compression codec and serializer support concatenation of
+   * serialized streams.
+   *
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    assert (spills.length >= 2);
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
+    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+    final long[] spillInputChannelPositions = new long[spills.length];
+    FileChannel mergedFileOutputChannel = null;
+
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+      }
+      // This file needs to be opened in append mode in order to work around a Linux kernel bug that
+      // affects transferTo; see SPARK-3948 for more details.
+      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+      long bytesWrittenToMergedFile = 0;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          long bytesToTransfer = partitionLengthInSpill;
+          final FileChannel spillInputChannel = spillInputChannels[i];
+          final long writeStartTime = System.nanoTime();
+          while (bytesToTransfer > 0) {
+            final long actualBytesTransferred = spillInputChannel.transferTo(
+              spillInputChannelPositions[i],
+              bytesToTransfer,
+              mergedFileOutputChannel);
+            spillInputChannelPositions[i] += actualBytesTransferred;
+            bytesToTransfer -= actualBytesTransferred;
+          }
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+          bytesWrittenToMergedFile += partitionLengthInSpill;
+          partitionLengths[partition] += partitionLengthInSpill;
+        }
+      }
+      // Check the position after transferTo loop to see if it is in the right position and raise an
+      // exception if it is incorrect. The position will not be increased to the expected length
+      // after calling transferTo in kernel version 2.6.32. This issue is described at
+      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+        throw new IOException(
+          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+            "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+            " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+            "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+            "to disable this NIO feature."
+        );
+      }
+      threwException = false;
+    } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
+      for (int i = 0; i < spills.length; i++) {
+        assert(spillInputChannelPositions[i] == spills[i].file.length());
+        Closeables.close(spillInputChannels[i], threwException);
+      }
+      Closeables.close(mergedFileOutputChannel, threwException);
+    }
+    return partitionLengths;
+  }
+
+  @Override
+  public Option<MapStatus> stop(boolean success) {
+    try {
+      // Update task metrics from accumulators (null in UnsafeShuffleWriterSuite)
+      Map<String, Accumulator<Object>> internalAccumulators =
+        taskContext.internalMetricsToAccumulators();
+      if (internalAccumulators != null) {
+        internalAccumulators.apply(InternalAccumulator.PEAK_EXECUTION_MEMORY())
+          .add(getPeakMemoryUsedBytes());
+      }
+
+      if (stopping) {
+        return Option.apply(null);
+      } else {
+        stopping = true;
+        if (success) {
+          if (mapStatus == null) {
+            throw new IllegalStateException("Cannot call stop(true) without having called write()");
+          }
+          return Option.apply(mapStatus);
+        } else {
+          // The map task failed, so delete our output data.
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+          return Option.apply(null);
+        }
+      }
+    } finally {
+      if (sorter != null) {
+        // If sorter is non-null, then this implies that we called stop() in response to an error,
+        // so we need to clean up memory and spill files created by the sorter
+        sorter.cleanupResources();
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
deleted file mode 100644
index 4ee6a82..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-/**
- * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
- * <p>
- * Within the long, the data is laid out as follows:
- * <pre>
- *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
- * </pre>
- * This implies that the maximum addressable page size is 2^27 bits = 128 megabytes, assuming that
- * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
- * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
- * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
- * <p>
- * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
- * optimization to future work as it will require more careful design to ensure that addresses are
- * properly aligned (e.g. by padding records).
- */
-final class PackedRecordPointer {
-
-  static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27;  // 128 megabytes
-
-  /**
-   * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
-   */
-  static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1;  // 16777215
-
-  /** Bit mask for the lower 40 bits of a long. */
-  private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
-
-  /** Bit mask for the upper 24 bits of a long */
-  private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
-
-  /** Bit mask for the lower 27 bits of a long. */
-  private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
-
-  /** Bit mask for the lower 51 bits of a long. */
-  private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
-
-  /** Bit mask for the upper 13 bits of a long */
-  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
-
-  /**
-   * Pack a record address and partition id into a single word.
-   *
-   * @param recordPointer a record pointer encoded by TaskMemoryManager.
-   * @param partitionId a shuffle partition id (maximum value of 2^24).
-   * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
-   */
-  public static long packPointer(long recordPointer, int partitionId) {
-    assert (partitionId <= MAXIMUM_PARTITION_ID);
-    // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
-    // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
-    final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
-    final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
-    return (((long) partitionId) << 40) | compressedAddress;
-  }
-
-  private long packedRecordPointer;
-
-  public void set(long packedRecordPointer) {
-    this.packedRecordPointer = packedRecordPointer;
-  }
-
-  public int getPartitionId() {
-    return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
-  }
-
-  public long getRecordPointer() {
-    final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
-    final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
-    return pageNumber | offsetInPage;
-  }
-
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
deleted file mode 100644
index 7bac0dc..0000000
--- a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.io.File;
-
-import org.apache.spark.storage.TempShuffleBlockId;
-
-/**
- * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
- */
-final class SpillInfo {
-  final long[] partitionLengths;
-  final File file;
-  final TempShuffleBlockId blockId;
-
-  public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
-    this.partitionLengths = new long[numPartitions];
-    this.file = file;
-    this.blockId = blockId;
-  }
-}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/4] spark git commit: [SPARK-10708] Consolidate sort shuffle implementations

Posted by jo...@apache.org.
http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
deleted file mode 100644
index 87a786b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ /dev/null
@@ -1,273 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.io.InputStream
-import java.nio.IntBuffer
-import java.util.Comparator
-
-import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
-import org.apache.spark.storage.DiskBlockObjectWriter
-import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
-
-/**
- * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
- * its records upon insert and stores them as raw bytes.
- *
- * We use two data-structures to store the contents. The serialized records are stored in a
- * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
- * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
- * record. Each entry in the metadata buffer takes up a fixed amount of space.
- *
- * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
- * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
- * happen without following any pointers, which should minimize cache misses.
- *
- * Currently, only sorting by partition is supported.
- *
- * Each record is laid out inside the the metaBuffer as follows. keyStart, a long, is split across
- * two integers:
- *
- *   +-------------+------------+------------+-------------+
- *   |         keyStart         | keyValLen  | partitionId |
- *   +-------------+------------+------------+-------------+
- *
- * The buffer can support up to `536870911 (2 ^ 29 - 1)` records.
- *
- * @param metaInitialRecords The initial number of entries in the metadata buffer.
- * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
- * @param serializerInstance the serializer used for serializing inserted records.
- */
-private[spark] class PartitionedSerializedPairBuffer[K, V](
-    metaInitialRecords: Int,
-    kvBlockSize: Int,
-    serializerInstance: SerializerInstance)
-  extends WritablePartitionedPairCollection[K, V] with SizeTracker {
-
-  if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
-    throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
-      " Java-serialized objects.")
-  }
-
-  require(metaInitialRecords <= MAXIMUM_RECORDS,
-    s"Can't make capacity bigger than ${MAXIMUM_RECORDS} records")
-  private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
-
-  private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
-  private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
-  private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
-
-  def insert(partition: Int, key: K, value: V): Unit = {
-    if (metaBuffer.position == metaBuffer.capacity) {
-      growMetaBuffer()
-    }
-
-    val keyStart = kvBuffer.size
-    kvSerializationStream.writeKey[Any](key)
-    kvSerializationStream.writeValue[Any](value)
-    kvSerializationStream.flush()
-    val keyValLen = (kvBuffer.size - keyStart).toInt
-
-    // keyStart, a long, gets split across two ints
-    metaBuffer.put(keyStart.toInt)
-    metaBuffer.put((keyStart >> 32).toInt)
-    metaBuffer.put(keyValLen)
-    metaBuffer.put(partition)
-  }
-
-  /** Double the size of the array because we've reached capacity */
-  private def growMetaBuffer(): Unit = {
-    if (metaBuffer.capacity >= MAXIMUM_META_BUFFER_CAPACITY) {
-      throw new IllegalStateException(s"Can't insert more than ${MAXIMUM_RECORDS} records")
-    }
-    val newCapacity =
-      if (metaBuffer.capacity * 2 < 0 || metaBuffer.capacity * 2 > MAXIMUM_META_BUFFER_CAPACITY) {
-        // Overflow
-        MAXIMUM_META_BUFFER_CAPACITY
-      } else {
-        metaBuffer.capacity * 2
-      }
-    val newMetaBuffer = IntBuffer.allocate(newCapacity)
-    newMetaBuffer.put(metaBuffer.array)
-    metaBuffer = newMetaBuffer
-  }
-
-  /** Iterate through the data in a given order. For this class this is not really destructive. */
-  override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
-    : Iterator[((Int, K), V)] = {
-    sort(keyComparator)
-    val is = orderedInputStream
-    val deserStream = serializerInstance.deserializeStream(is)
-    new Iterator[((Int, K), V)] {
-      var metaBufferPos = 0
-      def hasNext: Boolean = metaBufferPos < metaBuffer.position
-      def next(): ((Int, K), V) = {
-        val key = deserStream.readKey[Any]().asInstanceOf[K]
-        val value = deserStream.readValue[Any]().asInstanceOf[V]
-        val partition = metaBuffer.get(metaBufferPos + PARTITION)
-        metaBufferPos += RECORD_SIZE
-        ((partition, key), value)
-      }
-    }
-  }
-
-  override def estimateSize: Long = metaBuffer.capacity * 4L + kvBuffer.capacity
-
-  override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
-    : WritablePartitionedIterator = {
-    sort(keyComparator)
-    new WritablePartitionedIterator {
-      // current position in the meta buffer in ints
-      var pos = 0
-
-      def writeNext(writer: DiskBlockObjectWriter): Unit = {
-        val keyStart = getKeyStartPos(metaBuffer, pos)
-        val keyValLen = metaBuffer.get(pos + KEY_VAL_LEN)
-        pos += RECORD_SIZE
-        kvBuffer.read(keyStart, writer, keyValLen)
-        writer.recordWritten()
-      }
-      def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
-      def hasNext(): Boolean = pos < metaBuffer.position
-    }
-  }
-
-  // Visible for testing
-  def orderedInputStream: OrderedInputStream = {
-    new OrderedInputStream(metaBuffer, kvBuffer)
-  }
-
-  private def sort(keyComparator: Option[Comparator[K]]): Unit = {
-    val comparator = if (keyComparator.isEmpty) {
-      new Comparator[Int]() {
-        def compare(partition1: Int, partition2: Int): Int = {
-          partition1 - partition2
-        }
-      }
-    } else {
-      throw new UnsupportedOperationException()
-    }
-
-    val sorter = new Sorter(new SerializedSortDataFormat)
-    sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
-  }
-}
-
-private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
-    extends InputStream {
-
-  import PartitionedSerializedPairBuffer._
-
-  private var metaBufferPos = 0
-  private var kvBufferPos =
-    if (metaBuffer.position > 0) getKeyStartPos(metaBuffer, metaBufferPos) else 0
-
-  override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
-
-  override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
-    if (metaBufferPos >= metaBuffer.position) {
-      return -1
-    }
-    val bytesRemainingInRecord = (metaBuffer.get(metaBufferPos + KEY_VAL_LEN) -
-      (kvBufferPos - getKeyStartPos(metaBuffer, metaBufferPos))).toInt
-    val toRead = math.min(bytesRemainingInRecord, len)
-    kvBuffer.read(kvBufferPos, bytes, offs, toRead)
-    if (toRead == bytesRemainingInRecord) {
-      metaBufferPos += RECORD_SIZE
-      if (metaBufferPos < metaBuffer.position) {
-        kvBufferPos = getKeyStartPos(metaBuffer, metaBufferPos)
-      }
-    } else {
-      kvBufferPos += toRead
-    }
-    toRead
-  }
-
-  override def read(): Int = {
-    throw new UnsupportedOperationException()
-  }
-}
-
-private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
-
-  private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
-
-  /** Return the sort key for the element at the given index. */
-  override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
-    metaBuffer.get(pos * RECORD_SIZE + PARTITION)
-  }
-
-  /** Swap two elements. */
-  override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
-    val iOff = pos0 * RECORD_SIZE
-    val jOff = pos1 * RECORD_SIZE
-    System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
-    System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
-    System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
-  }
-
-  /** Copy a single element from src(srcPos) to dst(dstPos). */
-  override def copyElement(
-      src: IntBuffer,
-      srcPos: Int,
-      dst: IntBuffer,
-      dstPos: Int): Unit = {
-    val srcOff = srcPos * RECORD_SIZE
-    val dstOff = dstPos * RECORD_SIZE
-    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
-  }
-
-  /**
-   * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
-   * Overlapping ranges are allowed.
-   */
-  override def copyRange(
-      src: IntBuffer,
-      srcPos: Int,
-      dst: IntBuffer,
-      dstPos: Int,
-      length: Int): Unit = {
-    val srcOff = srcPos * RECORD_SIZE
-    val dstOff = dstPos * RECORD_SIZE
-    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
-  }
-
-  /**
-   * Allocates a Buffer that can hold up to 'length' elements.
-   * All elements of the buffer should be considered invalid until data is explicitly copied in.
-   */
-  override def allocate(length: Int): IntBuffer = {
-    IntBuffer.allocate(length * RECORD_SIZE)
-  }
-}
-
-private object PartitionedSerializedPairBuffer {
-  val KEY_START = 0 // keyStart, a long, gets split across two ints
-  val KEY_VAL_LEN = 2
-  val PARTITION = 3
-  val RECORD_SIZE = PARTITION + 1 // num ints of metadata
-
-  val MAXIMUM_RECORDS = Int.MaxValue / RECORD_SIZE // (2 ^ 29) - 1
-  val MAXIMUM_META_BUFFER_CAPACITY = MAXIMUM_RECORDS * RECORD_SIZE // (2 ^ 31) - 4
-
-  def getKeyStartPos(metaBuffer: IntBuffer, metaBufferPos: Int): Long = {
-    val lower32 = metaBuffer.get(metaBufferPos + KEY_START)
-    val upper32 = metaBuffer.get(metaBufferPos + KEY_START + 1)
-    (upper32.toLong << 32) | (lower32 & 0xFFFFFFFFL)
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
new file mode 100644
index 0000000..232ae4d
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import org.apache.spark.shuffle.sort.PackedRecordPointer;
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.sort.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+  @Test
+  public void heap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void offHeap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock page0 = memoryManager.allocatePage(128);
+    final MemoryBlock page1 = memoryManager.allocatePage(128);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void maximumPartitionIdCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+  }
+
+  @Test
+  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    try {
+      // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
+      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+      assertFalse(MAXIMUM_PARTITION_ID  + 1 == packedPointer.getPartitionId());
+    } catch (AssertionError e ) {
+      // pass
+    }
+  }
+
+  @Test
+  public void maximumOffsetInPageCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(address, packedPointer.getRecordPointer());
+  }
+
+  @Test
+  public void offsetsPastMaxOffsetInPageWillOverflow() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(0, packedPointer.getRecordPointer());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000..1ef3c5f
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/ShuffleInMemorySorterSuite.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class ShuffleInMemorySorterSuite {
+
+  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+    final byte[] strBytes = new byte[strLength];
+    Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
+    return new String(strBytes, java.nio.charset.StandardCharsets.UTF_8); // written as utf-8
+  }
+
+  @Test
+  public void testSortingEmptyInput() {
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(100);
+    final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+    Assert.assertFalse(iter.hasNext()); // JUnit check: still runs when JVM `-ea` is off
+  }
+
+  @Test
+  public void testBasicSorting() throws Exception {
+    final String[] dataToSort = new String[] {
+      "Boba",
+      "Pearls",
+      "Tapioca",
+      "Taho",
+      "Condensed Milk",
+      "Jasmine",
+      "Milk Tea",
+      "Lychee",
+      "Mango"
+    };
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final Object baseObject = dataPage.getBaseObject();
+    final ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+    final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+    // Write each record as [4-byte length][utf-8 bytes] and store its pointer in the sorter
+    long position = dataPage.getBaseOffset();
+    for (String str : dataToSort) {
+      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+      final byte[] strBytes = str.getBytes("utf-8");
+      Platform.putInt(baseObject, position, strBytes.length);
+      position += 4;
+      Platform.copyMemory(
+        strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
+      position += strBytes.length;
+      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+    }
+
+    // Sort the records; dataToSort is sorted too so that it can be binary-searched below
+    final ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+    int prevPartitionId = -1;
+    Arrays.sort(dataToSort);
+    for (int i = 0; i < dataToSort.length; i++) {
+      Assert.assertTrue(iter.hasNext());
+      iter.loadNext();
+      final int partitionId = iter.packedRecordPointer.getPartitionId();
+      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+        partitionId >= prevPartitionId);
+      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+      final int recordLength = Platform.getInt(
+        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+      final String str = getStringFromDataPage(
+        memoryManager.getPage(recordAddress),
+        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+        recordLength);
+      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) >= 0); // negative means absent
+    }
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
+  public void testSortingManyNumbers() throws Exception {
+    ShuffleInMemorySorter sorter = new ShuffleInMemorySorter(4);
+    int[] numbersToSort = new int[128000];
+    Random random = new Random(16);
+    for (int i = 0; i < numbersToSort.length; i++) {
+      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+      sorter.insertRecord(0, numbersToSort[i]);
+    }
+    Arrays.sort(numbersToSort);
+    int[] sorterResult = new int[numbersToSort.length];
+    ShuffleInMemorySorter.ShuffleSorterIterator iter = sorter.getSortedIterator();
+    int j = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+      j += 1;
+    }
+    Assert.assertArrayEquals(numbersToSort, sorterResult);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000..29d9823
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,560 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.Iterators;
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.sort.SerializedShuffleHandle;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+  static final int NUM_PARTITITONS = 4;
+  final TaskMemoryManager taskMemoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
+  File mergedOutputFile;
+  File tempDir;
+  long[] partitionSizesInMergedFile;
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  SparkConf conf;
+  final Serializer serializer = new KryoSerializer(new SparkConf());
+  TaskMetrics taskMetrics;
+
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      if (conf.getBoolean("spark.shuffle.compress", true)) {
+        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+      } else {
+        return stream;
+      }
+    }
+  }
+
+  @After
+  public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+    if (leakedMemory != 0) {
+      fail("Test leaked " + leakedMemory + " bytes of managed memory");
+    }
+  }
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    tempDir = Utils.createTempDir("test", "test");
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    partitionSizesInMergedFile = null;
+    spillFilesCreated.clear();
+    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
+    taskMetrics = new TaskMetrics();
+
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
+
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+      new Answer<InputStream>() {
+        @Override
+        public InputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          InputStream is = (InputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+          } else {
+            return is;
+          }
+        }
+      }
+    );
+
+    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+      new Answer<OutputStream>() {
+        @Override
+        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          OutputStream os = (OutputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+          } else {
+            return os;
+          }
+        }
+      }
+    );
+
+    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        return null;
+      }
+    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer<Tuple2<TempShuffleBlockId, File>>() {
+        @Override
+        public Tuple2<TempShuffleBlockId, File> answer(
+          InvocationOnMock invocationOnMock) throws Throwable {
+          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+          File file = File.createTempFile("spillFile", ".spill", tempDir);
+          spillFilesCreated.add(file);
+          return Tuple2$.MODULE$.apply(blockId, file);
+        }
+      });
+
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
+
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
+    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+    return new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockResolver,
+      taskMemoryManager,
+      shuffleMemoryManager,
+      new SerializedShuffleHandle<Object, Object>(0, 1, shuffleDep),
+      0, // map id
+      taskContext,
+      conf
+    );
+  }
+
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        in = new LimitedInputStream(in, partitionSize);
+        if (conf.getBoolean("spark.shuffle.compress", true)) {
+          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+        }
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          Tuple2<Object, Object> record = records.next();
+          assertEquals(i, hashPartitioner.getPartition(record._1()));
+          recordsList.add(record);
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
+  @Test(expected=IllegalStateException.class)
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+    createWriter(false).stop(true);
+  }
+
+  @Test
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+    createWriter(false).stop(false);
+  }
+
+  class PandaException extends RuntimeException {
+  }
+
+  @Test(expected=PandaException.class)
+  public void writeFailurePropagates() throws Exception {
+    class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
+      @Override public boolean hasNext() {
+        throw new PandaException();
+      }
+      @Override public Product2<Object, Object> next() {
+        return null;
+      }
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(new BadRecords());
+  }
+
+  @Test
+  public void writeEmptyIterator() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(Iterators.<Product2<Object, Object>>emptyIterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+  }
+
+  @Test
+  public void writeWithoutSpilling() throws Exception {
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(dataToWrite.iterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      // All partitions should be the same size:
+      assertEquals(partitionSizesInMergedFile[0], size);
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      String compressionCodecName) throws IOException {
+    if (compressionCodecName != null) {
+      conf.set("spark.shuffle.compress", "true");
+      conf.set("spark.io.compression.codec", compressionCodecName);
+    } else {
+      conf.set("spark.shuffle.compress", "false");
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZF() throws Exception {
+    testMergingSpills(true, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+    testMergingSpills(false, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+    testMergingSpills(true, null);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+    testMergingSpills(false, null);
+  }
+
+  @Test
+  public void writeEnoughDataToTriggerSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to allocate new data page
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+    for (int i = 0; i < 128 + 1; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to grow sort buffer
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (ShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
+    new Random(42).nextBytes(atMaxRecordSize);
+    dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
+    // Inserting a record that's larger than the max record size
+    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.stop(false);
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void testPeakMemoryUsed() throws Exception {
+    final long recordLengthBytes = 8;
+    final long pageSizeBytes = 256;
+    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
+    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
+    final UnsafeShuffleWriter<Object, Object> writer =
+      new UnsafeShuffleWriter<Object, Object>(
+        blockManager,
+        shuffleBlockResolver,
+        taskMemoryManager,
+        shuffleMemoryManager,
+        new SerializedShuffleHandle<>(0, 1, shuffleDep),
+        0, // map id
+        taskContext,
+        conf);
+
+    // Peak memory should be monotonically increasing. More specifically, every time
+    // we allocate a new page it should increase by exactly the size of the page.
+    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
+    long newPeakMemory;
+    try {
+      for (int i = 0; i < numRecordsPerPage * 10; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+        newPeakMemory = writer.getPeakMemoryUsedBytes();
+        if (i % numRecordsPerPage == 0 && i != 0) {
+          // The first page is allocated in constructor, another page will be allocated after
+          // every numRecordsPerPage records (peak memory should change).
+          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
+        } else {
+          assertEquals(previousPeakMemory, newPeakMemory);
+        }
+        previousPeakMemory = newPeakMemory;
+      }
+
+      // Spilling should not change peak memory
+      writer.forceSorterToSpill();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+      for (int i = 0; i < numRecordsPerPage; i++) {
+        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+      }
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+
+      // Closing the writer should not change peak memory
+      writer.closeAndWriteOutput();
+      newPeakMemory = writer.getPeakMemoryUsedBytes();
+      assertEquals(previousPeakMemory, newPeakMemory);
+    } finally {
+      writer.stop(false);
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
deleted file mode 100644
index 934b7e0..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
+++ /dev/null
@@ -1,101 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import org.junit.Test;
-import static org.junit.Assert.*;
-
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
-
-public class PackedRecordPointerSuite {
-
-  @Test
-  public void heap() {
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
-    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
-      page1.getBaseOffset() + 42);
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
-    assertEquals(360, packedPointer.getPartitionId());
-    final long recordPointer = packedPointer.getRecordPointer();
-    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
-    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
-    assertEquals(addressInPage1, recordPointer);
-    memoryManager.cleanUpAllAllocatedMemory();
-  }
-
-  @Test
-  public void offHeap() {
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
-    final MemoryBlock page0 = memoryManager.allocatePage(128);
-    final MemoryBlock page1 = memoryManager.allocatePage(128);
-    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
-      page1.getBaseOffset() + 42);
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
-    assertEquals(360, packedPointer.getPartitionId());
-    final long recordPointer = packedPointer.getRecordPointer();
-    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
-    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
-    assertEquals(addressInPage1, recordPointer);
-    memoryManager.cleanUpAllAllocatedMemory();
-  }
-
-  @Test
-  public void maximumPartitionIdCanBeEncoded() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
-    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
-  }
-
-  @Test
-  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    try {
-      // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
-      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
-      assertFalse(MAXIMUM_PARTITION_ID  + 1 == packedPointer.getPartitionId());
-    } catch (AssertionError e ) {
-      // pass
-    }
-  }
-
-  @Test
-  public void maximumOffsetInPageCanBeEncoded() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
-    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
-    assertEquals(address, packedPointer.getRecordPointer());
-  }
-
-  @Test
-  public void offsetsPastMaxOffsetInPageWillOverflow() {
-    PackedRecordPointer packedPointer = new PackedRecordPointer();
-    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
-    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
-    assertEquals(0, packedPointer.getRecordPointer());
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
deleted file mode 100644
index 40fefe2..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
+++ /dev/null
@@ -1,124 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.util.Arrays;
-import java.util.Random;
-
-import org.junit.Assert;
-import org.junit.Test;
-
-import org.apache.spark.HashPartitioner;
-import org.apache.spark.unsafe.Platform;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.MemoryBlock;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-
-public class UnsafeShuffleInMemorySorterSuite {
-
-  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
-    final byte[] strBytes = new byte[strLength];
-    Platform.copyMemory(baseObject, baseOffset, strBytes, Platform.BYTE_ARRAY_OFFSET, strLength);
-    return new String(strBytes);
-  }
-
-  @Test
-  public void testSortingEmptyInput() {
-    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
-    assert(!iter.hasNext());
-  }
-
-  @Test
-  public void testBasicSorting() throws Exception {
-    final String[] dataToSort = new String[] {
-      "Boba",
-      "Pearls",
-      "Tapioca",
-      "Taho",
-      "Condensed Milk",
-      "Jasmine",
-      "Milk Tea",
-      "Lychee",
-      "Mango"
-    };
-    final TaskMemoryManager memoryManager =
-      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
-    final Object baseObject = dataPage.getBaseObject();
-    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
-    final HashPartitioner hashPartitioner = new HashPartitioner(4);
-
-    // Write the records into the data page and store pointers into the sorter
-    long position = dataPage.getBaseOffset();
-    for (String str : dataToSort) {
-      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
-      final byte[] strBytes = str.getBytes("utf-8");
-      Platform.putInt(baseObject, position, strBytes.length);
-      position += 4;
-      Platform.copyMemory(
-        strBytes, Platform.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
-      position += strBytes.length;
-      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
-    }
-
-    // Sort the records
-    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
-    int prevPartitionId = -1;
-    Arrays.sort(dataToSort);
-    for (int i = 0; i < dataToSort.length; i++) {
-      Assert.assertTrue(iter.hasNext());
-      iter.loadNext();
-      final int partitionId = iter.packedRecordPointer.getPartitionId();
-      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
-      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
-        partitionId >= prevPartitionId);
-      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
-      final int recordLength = Platform.getInt(
-        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
-      final String str = getStringFromDataPage(
-        memoryManager.getPage(recordAddress),
-        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
-        recordLength);
-      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) != -1);
-    }
-    Assert.assertFalse(iter.hasNext());
-  }
-
-  @Test
-  public void testSortingManyNumbers() throws Exception {
-    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
-    int[] numbersToSort = new int[128000];
-    Random random = new Random(16);
-    for (int i = 0; i < numbersToSort.length; i++) {
-      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
-      sorter.insertRecord(0, numbersToSort[i]);
-    }
-    Arrays.sort(numbersToSort);
-    int[] sorterResult = new int[numbersToSort.length];
-    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
-    int j = 0;
-    while (iter.hasNext()) {
-      iter.loadNext();
-      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
-      j += 1;
-    }
-    Assert.assertArrayEquals(numbersToSort, sorterResult);
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
deleted file mode 100644
index d218344..0000000
--- a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
+++ /dev/null
@@ -1,560 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.shuffle.unsafe;
-
-import java.io.*;
-import java.nio.ByteBuffer;
-import java.util.*;
-
-import scala.*;
-import scala.collection.Iterator;
-import scala.reflect.ClassTag;
-import scala.runtime.AbstractFunction1;
-
-import com.google.common.collect.Iterators;
-import com.google.common.collect.HashMultiset;
-import com.google.common.io.ByteStreams;
-import org.junit.After;
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.Mock;
-import org.mockito.MockitoAnnotations;
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-import static org.hamcrest.MatcherAssert.assertThat;
-import static org.hamcrest.Matchers.greaterThan;
-import static org.hamcrest.Matchers.lessThan;
-import static org.junit.Assert.*;
-import static org.mockito.AdditionalAnswers.returnsFirstArg;
-import static org.mockito.Answers.RETURNS_SMART_NULLS;
-import static org.mockito.Mockito.*;
-
-import org.apache.spark.*;
-import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZ4CompressionCodec;
-import org.apache.spark.io.LZFCompressionCodec;
-import org.apache.spark.io.SnappyCompressionCodec;
-import org.apache.spark.executor.ShuffleWriteMetrics;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.network.util.LimitedInputStream;
-import org.apache.spark.serializer.*;
-import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.shuffle.IndexShuffleBlockResolver;
-import org.apache.spark.shuffle.ShuffleMemoryManager;
-import org.apache.spark.storage.*;
-import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
-import org.apache.spark.unsafe.memory.MemoryAllocator;
-import org.apache.spark.unsafe.memory.TaskMemoryManager;
-import org.apache.spark.util.Utils;
-
-public class UnsafeShuffleWriterSuite {
-
-  static final int NUM_PARTITITONS = 4;
-  final TaskMemoryManager taskMemoryManager =
-    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
-  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
-  File mergedOutputFile;
-  File tempDir;
-  long[] partitionSizesInMergedFile;
-  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
-  SparkConf conf;
-  final Serializer serializer = new KryoSerializer(new SparkConf());
-  TaskMetrics taskMetrics;
-
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
-  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
-  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
-  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
-  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
-  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
-
-  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
-    @Override
-    public OutputStream apply(OutputStream stream) {
-      if (conf.getBoolean("spark.shuffle.compress", true)) {
-        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
-      } else {
-        return stream;
-      }
-    }
-  }
-
-  @After
-  public void tearDown() {
-    Utils.deleteRecursively(tempDir);
-    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
-    if (leakedMemory != 0) {
-      fail("Test leaked " + leakedMemory + " bytes of managed memory");
-    }
-  }
-
-  @Before
-  @SuppressWarnings("unchecked")
-  public void setUp() throws IOException {
-    MockitoAnnotations.initMocks(this);
-    tempDir = Utils.createTempDir("test", "test");
-    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
-    partitionSizesInMergedFile = null;
-    spillFilesCreated.clear();
-    conf = new SparkConf().set("spark.buffer.pageSize", "128m");
-    taskMetrics = new TaskMetrics();
-
-    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(128L * 1024 * 1024);
-
-    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
-    when(blockManager.getDiskWriter(
-      any(BlockId.class),
-      any(File.class),
-      any(SerializerInstance.class),
-      anyInt(),
-      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
-      @Override
-      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
-        Object[] args = invocationOnMock.getArguments();
-
-        return new DiskBlockObjectWriter(
-          (File) args[1],
-          (SerializerInstance) args[2],
-          (Integer) args[3],
-          new CompressStream(),
-          false,
-          (ShuffleWriteMetrics) args[4]
-        );
-      }
-    });
-    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
-      new Answer<InputStream>() {
-        @Override
-        public InputStream answer(InvocationOnMock invocation) throws Throwable {
-          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
-          InputStream is = (InputStream) invocation.getArguments()[1];
-          if (conf.getBoolean("spark.shuffle.compress", true)) {
-            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
-          } else {
-            return is;
-          }
-        }
-      }
-    );
-
-    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
-      new Answer<OutputStream>() {
-        @Override
-        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
-          assert (invocation.getArguments()[0] instanceof TempShuffleBlockId);
-          OutputStream os = (OutputStream) invocation.getArguments()[1];
-          if (conf.getBoolean("spark.shuffle.compress", true)) {
-            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
-          } else {
-            return os;
-          }
-        }
-      }
-    );
-
-    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
-    doAnswer(new Answer<Void>() {
-      @Override
-      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
-        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
-        return null;
-      }
-    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
-
-    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
-      new Answer<Tuple2<TempShuffleBlockId, File>>() {
-        @Override
-        public Tuple2<TempShuffleBlockId, File> answer(
-          InvocationOnMock invocationOnMock) throws Throwable {
-          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
-          File file = File.createTempFile("spillFile", ".spill", tempDir);
-          spillFilesCreated.add(file);
-          return Tuple2$.MODULE$.apply(blockId, file);
-        }
-      });
-
-    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
-    when(taskContext.internalMetricsToAccumulators()).thenReturn(null);
-
-    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
-    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
-  }
-
-  private UnsafeShuffleWriter<Object, Object> createWriter(
-      boolean transferToEnabled) throws IOException {
-    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
-    return new UnsafeShuffleWriter<Object, Object>(
-      blockManager,
-      shuffleBlockResolver,
-      taskMemoryManager,
-      shuffleMemoryManager,
-      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
-      0, // map id
-      taskContext,
-      conf
-    );
-  }
-
-  private void assertSpillFilesWereCleanedUp() {
-    for (File spillFile : spillFilesCreated) {
-      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
-        spillFile.exists());
-    }
-  }
-
-  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
-    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
-    long startOffset = 0;
-    for (int i = 0; i < NUM_PARTITITONS; i++) {
-      final long partitionSize = partitionSizesInMergedFile[i];
-      if (partitionSize > 0) {
-        InputStream in = new FileInputStream(mergedOutputFile);
-        ByteStreams.skipFully(in, startOffset);
-        in = new LimitedInputStream(in, partitionSize);
-        if (conf.getBoolean("spark.shuffle.compress", true)) {
-          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
-        }
-        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
-        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
-        while (records.hasNext()) {
-          Tuple2<Object, Object> record = records.next();
-          assertEquals(i, hashPartitioner.getPartition(record._1()));
-          recordsList.add(record);
-        }
-        recordsStream.close();
-        startOffset += partitionSize;
-      }
-    }
-    return recordsList;
-  }
-
-  @Test(expected=IllegalStateException.class)
-  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
-    createWriter(false).stop(true);
-  }
-
-  @Test
-  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
-    createWriter(false).stop(false);
-  }
-
-  class PandaException extends RuntimeException {
-  }
-
-  @Test(expected=PandaException.class)
-  public void writeFailurePropagates() throws Exception {
-    class BadRecords extends scala.collection.AbstractIterator<Product2<Object, Object>> {
-      @Override public boolean hasNext() {
-        throw new PandaException();
-      }
-      @Override public Product2<Object, Object> next() {
-        return null;
-      }
-    }
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(new BadRecords());
-  }
-
-  @Test
-  public void writeEmptyIterator() throws Exception {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(Iterators.<Product2<Object, Object>>emptyIterator());
-    final Option<MapStatus> mapStatus = writer.stop(true);
-    assertTrue(mapStatus.isDefined());
-    assertTrue(mergedOutputFile.exists());
-    assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
-    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
-    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
-    assertEquals(0, taskMetrics.diskBytesSpilled());
-    assertEquals(0, taskMetrics.memoryBytesSpilled());
-  }
-
-  @Test
-  public void writeWithoutSpilling() throws Exception {
-    // In this example, each partition should have exactly one record:
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    for (int i = 0; i < NUM_PARTITITONS; i++) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, i));
-    }
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(dataToWrite.iterator());
-    final Option<MapStatus> mapStatus = writer.stop(true);
-    assertTrue(mapStatus.isDefined());
-    assertTrue(mergedOutputFile.exists());
-
-    long sumOfPartitionSizes = 0;
-    for (long size: partitionSizesInMergedFile) {
-      // All partitions should be the same size:
-      assertEquals(partitionSizesInMergedFile[0], size);
-      sumOfPartitionSizes += size;
-    }
-    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertEquals(0, taskMetrics.diskBytesSpilled());
-    assertEquals(0, taskMetrics.memoryBytesSpilled());
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  private void testMergingSpills(
-      boolean transferToEnabled,
-      String compressionCodecName) throws IOException {
-    if (compressionCodecName != null) {
-      conf.set("spark.shuffle.compress", "true");
-      conf.set("spark.io.compression.codec", compressionCodecName);
-    } else {
-      conf.set("spark.shuffle.compress", "false");
-    }
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, i));
-    }
-    writer.insertRecordIntoSorter(dataToWrite.get(0));
-    writer.insertRecordIntoSorter(dataToWrite.get(1));
-    writer.insertRecordIntoSorter(dataToWrite.get(2));
-    writer.insertRecordIntoSorter(dataToWrite.get(3));
-    writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(dataToWrite.get(4));
-    writer.insertRecordIntoSorter(dataToWrite.get(5));
-    writer.closeAndWriteOutput();
-    final Option<MapStatus> mapStatus = writer.stop(true);
-    assertTrue(mapStatus.isDefined());
-    assertTrue(mergedOutputFile.exists());
-    assertEquals(2, spillFilesCreated.size());
-
-    long sumOfPartitionSizes = 0;
-    for (long size: partitionSizesInMergedFile) {
-      sumOfPartitionSizes += size;
-    }
-    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
-
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
-    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
-    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndLZF() throws Exception {
-    testMergingSpills(true, LZFCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
-    testMergingSpills(false, LZFCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
-    testMergingSpills(true, LZ4CompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
-    testMergingSpills(false, LZ4CompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
-    testMergingSpills(true, SnappyCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
-    testMergingSpills(false, SnappyCompressionCodec.class.getName());
-  }
-
-  @Test
-  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
-    testMergingSpills(true, null);
-  }
-
-  @Test
-  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
-    testMergingSpills(false, null);
-  }
-
-  @Test
-  public void writeEnoughDataToTriggerSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to allocate new data page
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
-    for (int i = 0; i < 128 + 1; i++) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
-    }
-    writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
-    assertEquals(2, spillFilesCreated.size());
-    writer.stop(true);
-    readRecordsFromFile();
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
-    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
-    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  @Test
-  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
-    when(shuffleMemoryManager.tryToAcquire(anyLong()))
-      .then(returnsFirstArg()) // Allocate initial sort buffer
-      .then(returnsFirstArg()) // Allocate initial data page
-      .thenReturn(0L) // Deny request to grow sort buffer
-      .then(returnsFirstArg());  // Grant new sort buffer and data page.
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
-      dataToWrite.add(new Tuple2<Object, Object>(i, i));
-    }
-    writer.write(dataToWrite.iterator());
-    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
-    assertEquals(2, spillFilesCreated.size());
-    writer.stop(true);
-    readRecordsFromFile();
-    assertSpillFilesWereCleanedUp();
-    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
-    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
-    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
-    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
-    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
-    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
-  }
-
-  @Test
-  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
-    new Random(42).nextBytes(bytes);
-    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
-    writer.write(dataToWrite.iterator());
-    writer.stop(true);
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-  }
-
-  @Test
-  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
-    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(new byte[1])));
-    // We should be able to write a record that's right _at_ the max record size
-    final byte[] atMaxRecordSize = new byte[writer.maxRecordSizeBytes()];
-    new Random(42).nextBytes(atMaxRecordSize);
-    dataToWrite.add(new Tuple2<Object, Object>(2, ByteBuffer.wrap(atMaxRecordSize)));
-    // Inserting a record that's larger than the max record size
-    final byte[] exceedsMaxRecordSize = new byte[writer.maxRecordSizeBytes() + 1];
-    new Random(42).nextBytes(exceedsMaxRecordSize);
-    dataToWrite.add(new Tuple2<Object, Object>(3, ByteBuffer.wrap(exceedsMaxRecordSize)));
-    writer.write(dataToWrite.iterator());
-    writer.stop(true);
-    assertEquals(
-      HashMultiset.create(dataToWrite),
-      HashMultiset.create(readRecordsFromFile()));
-    assertSpillFilesWereCleanedUp();
-  }
-
-  @Test
-  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
-    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.forceSorterToSpill();
-    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
-    writer.stop(false);
-    assertSpillFilesWereCleanedUp();
-  }
-
-  @Test
-  public void testPeakMemoryUsed() throws Exception {
-    final long recordLengthBytes = 8;
-    final long pageSizeBytes = 256;
-    final long numRecordsPerPage = pageSizeBytes / recordLengthBytes;
-    when(shuffleMemoryManager.pageSizeBytes()).thenReturn(pageSizeBytes);
-    final UnsafeShuffleWriter<Object, Object> writer =
-      new UnsafeShuffleWriter<Object, Object>(
-        blockManager,
-        shuffleBlockResolver,
-        taskMemoryManager,
-        shuffleMemoryManager,
-        new UnsafeShuffleHandle<>(0, 1, shuffleDep),
-        0, // map id
-        taskContext,
-        conf);
-
-    // Peak memory should be monotonically increasing. More specifically, every time
-    // we allocate a new page it should increase by exactly the size of the page.
-    long previousPeakMemory = writer.getPeakMemoryUsedBytes();
-    long newPeakMemory;
-    try {
-      for (int i = 0; i < numRecordsPerPage * 10; i++) {
-        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-        newPeakMemory = writer.getPeakMemoryUsedBytes();
-        if (i % numRecordsPerPage == 0 && i != 0) {
-          // The first page is allocated in constructor, another page will be allocated after
-          // every numRecordsPerPage records (peak memory should change).
-          assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
-        } else {
-          assertEquals(previousPeakMemory, newPeakMemory);
-        }
-        previousPeakMemory = newPeakMemory;
-      }
-
-      // Spilling should not change peak memory
-      writer.forceSorterToSpill();
-      newPeakMemory = writer.getPeakMemoryUsedBytes();
-      assertEquals(previousPeakMemory, newPeakMemory);
-      for (int i = 0; i < numRecordsPerPage; i++) {
-        writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-      }
-      newPeakMemory = writer.getPeakMemoryUsedBytes();
-      assertEquals(previousPeakMemory, newPeakMemory);
-
-      // Closing the writer should not change peak memory
-      writer.closeAndWriteOutput();
-      newPeakMemory = writer.getPeakMemoryUsedBytes();
-      assertEquals(previousPeakMemory, newPeakMemory);
-    } finally {
-      writer.stop(false);
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 6335817..b8ab227 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -17,13 +17,88 @@
 
 package org.apache.spark
 
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
 
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.util.Utils
+
 class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
 
   // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
 
+  // Per-test scratch directory: created in beforeEach, deleted in afterEach.
+  private var tempDir: File = _
+
   override def beforeAll() {
     conf.set("spark.shuffle.manager", "sort")
   }
+
+  // Points spark.local.dir at a newly created temporary directory before each test.
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    tempDir = Utils.createTempDir()
+    conf.set("spark.local.dir", tempDir.getAbsolutePath)
+  }
+
+  // Deletes the per-test directory; the parent teardown runs even if deletion throws.
+  override def afterEach(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
+    sc = new SparkContext("local", "test", conf)
+    // Create a shuffled RDD and verify that it actually uses the new serialized map output path
+    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+      .setSerializer(new KryoSerializer(conf))
+    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+  test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
+    sc = new SparkContext("local", "test", conf)
+    // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
+    val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+    val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+      .setSerializer(new JavaSerializer(conf))
+    val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+    assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+  /**
+   * Runs the shuffle for `shuffledRdd`, verifies the files it created under `tempDir`, then
+   * removes shuffle 0 via the block manager master and asserts that those files are gone.
+   */
+  private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
+    def getAllFiles: Set[File] =
+      FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+    val filesBeforeShuffle = getAllFiles
+    // Force the shuffle to be performed
+    shuffledRdd.count()
+    // Ensure that the shuffle actually created files that will need to be cleaned up
+    val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+    // Pass the expected names to `be` as an argument: written on its own line, the Set would be
+    // a separate, discarded expression and the matcher would check nothing.
+    filesCreatedByShuffle.map(_.getName) should be (
+      Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+    // Check that the cleanup actually removes the files
+    sc.env.blockManager.master.removeShuffle(0, blocking = true)
+    for (file <- filesCreatedByShuffle) {
+      assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f6d06adf/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 5b01ddb..3816b8c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1062,10 +1062,10 @@ class DAGSchedulerSuite
    */
   test("don't submit stage until its dependencies map outputs are registered (SPARK-5259)") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
     val firstShuffleId = firstShuffleDep.shuffleId
     val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
-    val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(2))
     val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
     submit(reduceRdd, Array(0))
 
@@ -1175,7 +1175,7 @@ class DAGSchedulerSuite
    */
   test("register map outputs correctly after ExecutorLost and task Resubmitted") {
     val firstRDD = new MyRDD(sc, 3, Nil)
-    val firstShuffleDep = new ShuffleDependency(firstRDD, null)
+    val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(2))
     val reduceRdd = new MyRDD(sc, 5, List(firstShuffleDep))
     submit(reduceRdd, Array(0))
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org