You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/14 02:07:36 UTC

[1/2] spark git commit: [SPARK-7081] Faster sort-based shuffle path using binary processing cache-aware sort

Repository: spark
Updated Branches:
  refs/heads/master 61d1e87c0 -> 73bed408f


http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
new file mode 100644
index 0000000..f2bfef3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark._
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle._
+import org.apache.spark.shuffle.sort.SortShuffleManager
+
+/**
+ * Marker subclass of [[BaseShuffleHandle]]. UnsafeShuffleManager hands one out when a shuffle
+ * qualifies for the optimized path and later matches on it to select the unsafe writer.
+ */
+private[spark] class UnsafeShuffleHandle[K, V](
+    shuffleId: Int,
+    numMaps: Int,
+    dependency: ShuffleDependency[K, V, V])
+  extends BaseShuffleHandle(shuffleId, numMaps, dependency)
+
+private[spark] object UnsafeShuffleManager extends Logging {
+
+  /** The maximum number of shuffle output partitions that UnsafeShuffleManager supports. */
+  val MAX_SHUFFLE_OUTPUT_PARTITIONS: Int = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1
+
+  /**
+   * Decides whether a shuffle can use the optimized unsafe path or must fall back to the original
+   * sort-based shuffle: the serializer has to support relocation of serialized objects, there
+   * must be no aggregator or key ordering, and the partition count may not exceed
+   * MAX_SHUFFLE_OUTPUT_PARTITIONS. The outcome is reported through logDebug.
+   */
+  def canUseUnsafeShuffle[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = {
+    val shufId = dependency.shuffleId
+    val serializer = Serializer.getSerializer(dependency.serializer)
+    if (!serializer.supportsRelocationOfSerializedObjects) {
+      logDebug(s"Can't use UnsafeShuffle for shuffle $shufId because the serializer, " +
+        s"${serializer.getClass.getName}, does not support object relocation")
+      false
+    } else if (dependency.aggregator.isDefined) {
+      logDebug(s"Can't use UnsafeShuffle for shuffle $shufId because an aggregator is defined")
+      false
+    } else if (dependency.keyOrdering.isDefined) {
+      logDebug(s"Can't use UnsafeShuffle for shuffle $shufId because a key ordering is defined")
+      false
+    } else if (dependency.partitioner.numPartitions > MAX_SHUFFLE_OUTPUT_PARTITIONS) {
+      logDebug(s"Can't use UnsafeShuffle for shuffle $shufId because it has more than " +
+        s"$MAX_SHUFFLE_OUTPUT_PARTITIONS partitions")
+      false
+    } else {
+      logDebug(s"Can use UnsafeShuffle for shuffle $shufId")
+      true
+    }
+  }
+}
+
+/**
+ * A shuffle implementation that uses directly-managed memory to implement several performance
+ * optimizations for certain types of shuffles. In cases where the new performance optimizations
+ * cannot be applied, this shuffle manager delegates to [[SortShuffleManager]] to handle those
+ * shuffles.
+ *
+ * UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:
+ *
+ *  - The shuffle dependency specifies no aggregation or output ordering.
+ *  - The shuffle serializer supports relocation of serialized values (this is currently supported
+ *    by KryoSerializer and Spark SQL's custom serializers).
+ *  - The shuffle produces at most MAX_SHUFFLE_OUTPUT_PARTITIONS (16777216) output partitions.
+ *  - No individual record is larger than 128 MB when serialized.
+ *
+ * In addition, extra spill-merging optimizations are automatically applied when the shuffle
+ * compression codec supports concatenation of serialized streams. This is currently supported by
+ * Spark's LZF compression codec.
+ *
+ * At a high-level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.
+ * In sort-based shuffle, incoming records are sorted according to their target partition ids, then
+ * written to a single map output file. Reducers fetch contiguous regions of this file in order to
+ * read their portion of the map output. In cases where the map output data is too large to fit in
+ * memory, sorted subsets of the output are spilled to disk and those on-disk files are merged
+ * to produce the final output file.
+ *
+ * UnsafeShuffleManager optimizes this process in several ways:
+ *
+ *  - Its sort operates on serialized binary data rather than Java objects, which reduces memory
+ *    consumption and GC overheads. This optimization requires the record serializer to have certain
+ *    properties to allow serialized records to be re-ordered without requiring deserialization.
+ *    See SPARK-4550, where this optimization was first proposed and implemented, for more details.
+ *
+ *  - It uses a specialized cache-efficient sorter ([[UnsafeShuffleExternalSorter]]) that sorts
+ *    arrays of compressed record pointers and partition ids. By using only 8 bytes of space per
+ *    record in the sorting array, this fits more of the array into cache.
+ *
+ *  - The spill merging procedure operates on blocks of serialized records that belong to the same
+ *    partition and does not need to deserialize records during the merge.
+ *
+ *  - When the spill compression codec supports concatenation of compressed data, the spill merge
+ *    simply concatenates the serialized and compressed spill partitions to produce the final output
+ *    partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used
+ *    and avoids the need to allocate decompression or copying buffers during the merge.
+ *
+ * For more details on UnsafeShuffleManager's design, see SPARK-7081.
+ */
+private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging {
+
+  if (!conf.getBoolean("spark.shuffle.spill", true)) {
+    logWarning(
+      "spark.shuffle.spill was set to false, but this is ignored by the tungsten-sort shuffle " +
+      "manager; its optimized shuffles will continue to spill to disk when necessary.")
+  }
+
+  private[this] val sortShuffleManager: SortShuffleManager = new SortShuffleManager(conf)
+  private[this] val shufflesThatFellBackToSortShuffle =
+    Collections.newSetFromMap(new ConcurrentHashMap[Int, java.lang.Boolean]())
+  private[this] val numMapsForShufflesThatUsedNewPath = new ConcurrentHashMap[Int, Int]()
+
+  /** Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    if (UnsafeShuffleManager.canUseUnsafeShuffle(dependency)) {
+      // canUseUnsafeShuffle rejects dependencies with an aggregator, so C is treated as V here.
+      new UnsafeShuffleHandle[K, V](
+        shuffleId, numMaps, dependency.asInstanceOf[ShuffleDependency[K, V, V]])
+    } else {
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+  }
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C] = {
+    sortShuffleManager.getReader(handle, startPartition, endPartition, context)
+  }
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext): ShuffleWriter[K, V] = {
+    handle match {
+      // K and V are erased at runtime, hence @unchecked; only the handle's class is tested.
+      case unsafeShuffleHandle: UnsafeShuffleHandle[K @unchecked, V @unchecked] =>
+        numMapsForShufflesThatUsedNewPath.putIfAbsent(handle.shuffleId, unsafeShuffleHandle.numMaps)
+        val env = SparkEnv.get
+        new UnsafeShuffleWriter(
+          env.blockManager,
+          shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
+          unsafeShuffleHandle,
+          mapId,
+          context,
+          env.conf)
+      case _ =>
+        shufflesThatFellBackToSortShuffle.add(handle.shuffleId)
+        sortShuffleManager.getWriter(handle, mapId, context)
+    }
+  }
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  override def unregisterShuffle(shuffleId: Int): Boolean = {
+    if (shufflesThatFellBackToSortShuffle.remove(shuffleId)) {
+      sortShuffleManager.unregisterShuffle(shuffleId)
+    } else {
+      Option(numMapsForShufflesThatUsedNewPath.remove(shuffleId)).foreach { numMaps =>
+        (0 until numMaps).foreach { mapId =>
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId)
+        }
+      }
+      true
+    }
+  }
+
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
+    sortShuffleManager.shuffleBlockResolver
+  }
+
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {
+    sortShuffleManager.stop()
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 8bc4e20..a33f22e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -86,16 +86,6 @@ private[spark] class DiskBlockObjectWriter(
   extends BlockObjectWriter(blockId)
   with Logging
 {
-  /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
-  private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
-    override def write(i: Int): Unit = callWithTiming(out.write(i))
-    override def write(b: Array[Byte]): Unit = callWithTiming(out.write(b))
-    override def write(b: Array[Byte], off: Int, len: Int): Unit = {
-      callWithTiming(out.write(b, off, len))
-    }
-    override def close(): Unit = out.close()
-    override def flush(): Unit = out.flush()
-  }
 
   /** The file channel, used for repositioning / truncating the file. */
   private var channel: FileChannel = null
@@ -136,7 +126,7 @@ private[spark] class DiskBlockObjectWriter(
       throw new IllegalStateException("Writer already closed. Cannot be reopened.")
     }
     fos = new FileOutputStream(file, true)
-    ts = new TimeTrackingOutputStream(fos)
+    ts = new TimeTrackingOutputStream(writeMetrics, fos)
     channel = fos.getChannel()
     bs = compressStream(new BufferedOutputStream(ts, bufferSize))
     objOut = serializerInstance.serializeStream(bs)
@@ -150,9 +140,9 @@ private[spark] class DiskBlockObjectWriter(
         if (syncWrites) {
           // Force outstanding writes to disk and track how long it takes
           objOut.flush()
-          callWithTiming {
-            fos.getFD.sync()
-          }
+          val start = System.nanoTime()
+          fos.getFD.sync()
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
         }
       } {
         objOut.close()
@@ -251,12 +241,6 @@ private[spark] class DiskBlockObjectWriter(
     reportedPosition = pos
   }
 
-  private def callWithTiming(f: => Unit) = {
-    val start = System.nanoTime()
-    f
-    writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
-  }
-
   // For testing
   private[spark] override def flush() {
     objOut.flush()

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index b850973..df2d6ad 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -90,7 +90,7 @@ class ExternalAppendOnlyMap[K, V, C](
   // Number of bytes spilled in total
   private var _diskBytesSpilled = 0L
   
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = 
     sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
 

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 7d5cf7b..3b9d14f 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -110,7 +110,7 @@ private[spark] class ExternalSorter[K, V, C](
   private val conf = SparkEnv.get.conf
   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
   
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
   private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
new file mode 100644
index 0000000..db9e827
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/PackedRecordPointerSuite.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import static org.apache.spark.shuffle.unsafe.PackedRecordPointer.*;
+
+public class PackedRecordPointerSuite {
+
+  @Test
+  public void heap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock page0 = memoryManager.allocatePage(100);
+    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void offHeap() {
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock page0 = memoryManager.allocatePage(100);
+    final MemoryBlock page1 = memoryManager.allocatePage(100);
+    final long addressInPage1 = memoryManager.encodePageNumberAndOffset(page1,
+      page1.getBaseOffset() + 42);
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(addressInPage1, 360));
+    assertEquals(360, packedPointer.getPartitionId());
+    final long recordPointer = packedPointer.getRecordPointer();
+    assertEquals(1, TaskMemoryManager.decodePageNumber(recordPointer));
+    assertEquals(page1.getBaseOffset() + 42, memoryManager.getOffsetInPage(recordPointer));
+    assertEquals(addressInPage1, recordPointer);
+    memoryManager.cleanUpAllAllocatedMemory();
+  }
+
+  @Test
+  public void maximumPartitionIdCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID));
+    assertEquals(MAXIMUM_PARTITION_ID, packedPointer.getPartitionId());
+  }
+
+  @Test
+  public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    try {
+      // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
+      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
+      assertFalse(MAXIMUM_PARTITION_ID  + 1 == packedPointer.getPartitionId());
+    } catch (AssertionError e ) {
+      // pass
+    }
+  }
+
+  @Test
+  public void maximumOffsetInPageCanBeEncoded() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES - 1);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(address, packedPointer.getRecordPointer());
+  }
+
+  @Test
+  public void offsetsPastMaxOffsetInPageWillOverflow() {
+    PackedRecordPointer packedPointer = new PackedRecordPointer();
+    long address = TaskMemoryManager.encodePageNumberAndOffset(0, MAXIMUM_PAGE_SIZE_BYTES);
+    packedPointer.set(PackedRecordPointer.packPointer(address, 0));
+    assertEquals(0, packedPointer.getRecordPointer());
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
new file mode 100644
index 0000000..8fa7259
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorterSuite.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Arrays;
+import java.util.Random;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+public class UnsafeShuffleInMemorySorterSuite {
+
+  /** Copies strLength bytes starting at baseOffset in baseObject and decodes them as a String. */
+  private static String getStringFromDataPage(Object baseObject, long baseOffset, int strLength) {
+    final byte[] strBytes = new byte[strLength];
+    PlatformDependent.copyMemory(
+      baseObject, baseOffset, strBytes, PlatformDependent.BYTE_ARRAY_OFFSET, strLength);
+    return new String(strBytes);
+  }
+
+  @Test
+  public void testSortingEmptyInput() {
+    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(100);
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    // Use a JUnit assertion: the Java `assert` keyword is skipped unless the JVM runs with -ea.
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
+  public void testBasicSorting() throws Exception {
+    final String[] dataToSort = new String[] {
+      "Boba",
+      "Pearls",
+      "Tapioca",
+      "Taho",
+      "Condensed Milk",
+      "Jasmine",
+      "Milk Tea",
+      "Lychee",
+      "Mango"
+    };
+    final TaskMemoryManager memoryManager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = memoryManager.allocatePage(2048);
+    final Object baseObject = dataPage.getBaseObject();
+    final UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+    // Write the records into the data page and store pointers into the sorter
+    // (each record is a 4-byte length prefix followed by the string's UTF-8 bytes)
+    long position = dataPage.getBaseOffset();
+    for (String str : dataToSort) {
+      final long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage, position);
+      final byte[] strBytes = str.getBytes("utf-8");
+      PlatformDependent.UNSAFE.putInt(baseObject, position, strBytes.length);
+      position += 4;
+      PlatformDependent.copyMemory(
+        strBytes, PlatformDependent.BYTE_ARRAY_OFFSET, baseObject, position, strBytes.length);
+      position += strBytes.length;
+      sorter.insertRecord(recordAddress, hashPartitioner.getPartition(str));
+    }
+
+    // Sort the records
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    int prevPartitionId = -1;
+    // Sort the expected strings so Arrays.binarySearch can be used for the membership check
+    Arrays.sort(dataToSort);
+    for (int i = 0; i < dataToSort.length; i++) {
+      Assert.assertTrue(iter.hasNext());
+      iter.loadNext();
+      final int partitionId = iter.packedRecordPointer.getPartitionId();
+      Assert.assertTrue(partitionId >= 0 && partitionId <= 3);
+      Assert.assertTrue("Partition id " + partitionId + " should be >= prev id " + prevPartitionId,
+        partitionId >= prevPartitionId);
+      prevPartitionId = partitionId;
+      final long recordAddress = iter.packedRecordPointer.getRecordPointer();
+      final int recordLength = PlatformDependent.UNSAFE.getInt(
+        memoryManager.getPage(recordAddress), memoryManager.getOffsetInPage(recordAddress));
+      final String str = getStringFromDataPage(
+        memoryManager.getPage(recordAddress),
+        memoryManager.getOffsetInPage(recordAddress) + 4, // skip over record length
+        recordLength);
+      // binarySearch returns a negative value (not necessarily -1) when the key is missing
+      Assert.assertTrue(Arrays.binarySearch(dataToSort, str) >= 0);
+    }
+    Assert.assertFalse(iter.hasNext());
+  }
+
+  @Test
+  public void testSortingManyNumbers() throws Exception {
+    UnsafeShuffleInMemorySorter sorter = new UnsafeShuffleInMemorySorter(4);
+    int[] numbersToSort = new int[128000];
+    Random random = new Random(16);
+    for (int i = 0; i < numbersToSort.length; i++) {
+      numbersToSort[i] = random.nextInt(PackedRecordPointer.MAXIMUM_PARTITION_ID + 1);
+      // Every record shares address 0; only the partition ids are checked after sorting
+      sorter.insertRecord(0, numbersToSort[i]);
+    }
+    Arrays.sort(numbersToSort);
+    int[] sorterResult = new int[numbersToSort.length];
+    UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator iter = sorter.getSortedIterator();
+    int j = 0;
+    while (iter.hasNext()) {
+      iter.loadNext();
+      sorterResult[j] = iter.packedRecordPointer.getPartitionId();
+      j += 1;
+    }
+    Assert.assertArrayEquals(numbersToSort, sorterResult);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
new file mode 100644
index 0000000..730d265
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.ByteBuffer;
+import java.util.*;
+
+import scala.*;
+import scala.collection.Iterator;
+import scala.reflect.ClassTag;
+import scala.runtime.AbstractFunction1;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.io.ByteStreams;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.lessThan;
+import static org.junit.Assert.*;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.Answers.RETURNS_SMART_NULLS;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZ4CompressionCodec;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.io.SnappyCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+public class UnsafeShuffleWriterSuite {
+
+  static final int NUM_PARTITITONS = 4;
+  final TaskMemoryManager taskMemoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  final HashPartitioner hashPartitioner = new HashPartitioner(NUM_PARTITITONS);
+  File mergedOutputFile;
+  File tempDir;
+  long[] partitionSizesInMergedFile;
+  final LinkedList<File> spillFilesCreated = new LinkedList<File>();
+  SparkConf conf;
+  final Serializer serializer = new KryoSerializer(new SparkConf());
+  TaskMetrics taskMetrics;
+
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleMemoryManager shuffleMemoryManager;
+  @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) IndexShuffleBlockResolver shuffleBlockResolver;
+  @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager;
+  @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext;
+  @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep;
+
+  private final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      if (conf.getBoolean("spark.shuffle.compress", true)) {
+        return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream);
+      } else {
+        return stream;
+      }
+    }
+  }
+
+  @After
+  public void tearDown() {
+    Utils.deleteRecursively(tempDir);
+    final long leakedMemory = taskMemoryManager.cleanUpAllAllocatedMemory();
+    if (leakedMemory != 0) {
+      fail("Test leaked " + leakedMemory + " bytes of managed memory");
+    }
+  }
+
+  @Before
+  @SuppressWarnings("unchecked")
+  public void setUp() throws IOException {
+    MockitoAnnotations.initMocks(this);
+    tempDir = Utils.createTempDir("test", "test");
+    mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    partitionSizesInMergedFile = null;
+    spillFilesCreated.clear();
+    conf = new SparkConf();
+    taskMetrics = new TaskMetrics();
+
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (BlockId) args[0],
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class))).thenAnswer(
+      new Answer<InputStream>() {
+        @Override
+        public InputStream answer(InvocationOnMock invocation) throws Throwable {
+          assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          InputStream is = (InputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(is);
+          } else {
+            return is;
+          }
+        }
+      }
+    );
+
+    when(blockManager.wrapForCompression(any(BlockId.class), any(OutputStream.class))).thenAnswer(
+      new Answer<OutputStream>() {
+        @Override
+        public OutputStream answer(InvocationOnMock invocation) throws Throwable {
+          assertTrue(invocation.getArguments()[0] instanceof TempShuffleBlockId);
+          OutputStream os = (OutputStream) invocation.getArguments()[1];
+          if (conf.getBoolean("spark.shuffle.compress", true)) {
+            return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(os);
+          } else {
+            return os;
+          }
+        }
+      }
+    );
+
+    when(shuffleBlockResolver.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+        return null;
+      }
+    }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer<Tuple2<TempShuffleBlockId, File>>() {
+        @Override
+        public Tuple2<TempShuffleBlockId, File> answer(
+          InvocationOnMock invocationOnMock) throws Throwable {
+          TempShuffleBlockId blockId = new TempShuffleBlockId(UUID.randomUUID());
+          File file = File.createTempFile("spillFile", ".spill", tempDir);
+          spillFilesCreated.add(file);
+          return Tuple2$.MODULE$.apply(blockId, file);
+        }
+      });
+
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics);
+
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+    when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
+  }
+
+  private UnsafeShuffleWriter<Object, Object> createWriter(
+      boolean transferToEnabled) throws IOException {
+    conf.set("spark.file.transferTo", String.valueOf(transferToEnabled));
+    return new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockResolver,
+      taskMemoryManager,
+      shuffleMemoryManager,
+      new UnsafeShuffleHandle<Object, Object>(0, 1, shuffleDep),
+      0, // map id
+      taskContext,
+      conf
+    );
+  }
+
+  private void assertSpillFilesWereCleanedUp() {
+    for (File spillFile : spillFilesCreated) {
+      assertFalse("Spill file " + spillFile.getPath() + " was not cleaned up",
+        spillFile.exists());
+    }
+  }
+
+  private List<Tuple2<Object, Object>> readRecordsFromFile() throws IOException {
+    final ArrayList<Tuple2<Object, Object>> recordsList = new ArrayList<Tuple2<Object, Object>>();
+    long startOffset = 0;
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      final long partitionSize = partitionSizesInMergedFile[i];
+      if (partitionSize > 0) {
+        InputStream in = new FileInputStream(mergedOutputFile);
+        ByteStreams.skipFully(in, startOffset);
+        in = new LimitedInputStream(in, partitionSize);
+        if (conf.getBoolean("spark.shuffle.compress", true)) {
+          in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in);
+        }
+        DeserializationStream recordsStream = serializer.newInstance().deserializeStream(in);
+        Iterator<Tuple2<Object, Object>> records = recordsStream.asKeyValueIterator();
+        while (records.hasNext()) {
+          Tuple2<Object, Object> record = records.next();
+          assertEquals(i, hashPartitioner.getPartition(record._1()));
+          recordsList.add(record);
+        }
+        recordsStream.close();
+        startOffset += partitionSize;
+      }
+    }
+    return recordsList;
+  }
+
+  @Test(expected=IllegalStateException.class)
+  public void mustCallWriteBeforeSuccessfulStop() throws IOException {
+    createWriter(false).stop(true);
+  }
+
+  @Test
+  public void doNotNeedToCallWriteBeforeUnsuccessfulStop() throws IOException {
+    createWriter(false).stop(false);
+  }
+
+  @Test
+  public void writeEmptyIterator() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(Collections.<Product2<Object, Object>>emptyIterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile);
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.shuffleWriteMetrics().get().shuffleBytesWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+  }
+
+  @Test
+  public void writeWithoutSpilling() throws Exception {
+    // In this example, each partition should have exactly one record:
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < NUM_PARTITITONS; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.write(dataToWrite.iterator());
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      // All partitions should be the same size:
+      assertEquals(partitionSizesInMergedFile[0], size);
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertEquals(0, taskMetrics.diskBytesSpilled());
+    assertEquals(0, taskMetrics.memoryBytesSpilled());
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  private void testMergingSpills(
+      boolean transferToEnabled,
+      String compressionCodecName) throws IOException {
+    if (compressionCodecName != null) {
+      conf.set("spark.shuffle.compress", "true");
+      conf.set("spark.io.compression.codec", compressionCodecName);
+    } else {
+      conf.set("spark.shuffle.compress", "false");
+    }
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.insertRecordIntoSorter(dataToWrite.get(0));
+    writer.insertRecordIntoSorter(dataToWrite.get(1));
+    writer.insertRecordIntoSorter(dataToWrite.get(2));
+    writer.insertRecordIntoSorter(dataToWrite.get(3));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(dataToWrite.get(4));
+    writer.insertRecordIntoSorter(dataToWrite.get(5));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    assertTrue(mapStatus.isDefined());
+    assertTrue(mergedOutputFile.exists());
+    assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    assertEquals(sumOfPartitionSizes, mergedOutputFile.length());
+
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZF() throws Exception {
+    testMergingSpills(true, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZF() throws Exception {
+    testMergingSpills(false, LZFCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndLZ4() throws Exception {
+    testMergingSpills(true, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndLZ4() throws Exception {
+    testMergingSpills(false, LZ4CompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndSnappy() throws Exception {
+    testMergingSpills(true, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndSnappy() throws Exception {
+    testMergingSpills(false, SnappyCompressionCodec.class.getName());
+  }
+
+  @Test
+  public void mergeSpillsWithTransferToAndNoCompression() throws Exception {
+    testMergingSpills(true, null);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStreamAndNoCompression() throws Exception {
+    testMergingSpills(false, null);
+  }
+
+  @Test
+  public void writeEnoughDataToTriggerSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to allocate new data page
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    final byte[] bigByteArray = new byte[PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES / 128];
+    for (int i = 0; i < 128 + 1; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, bigByteArray));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exception {
+    when(shuffleMemoryManager.tryToAcquire(anyLong()))
+      .then(returnsFirstArg()) // Allocate initial sort buffer
+      .then(returnsFirstArg()) // Allocate initial data page
+      .thenReturn(0L) // Deny request to grow sort buffer
+      .then(returnsFirstArg());  // Grant new sort buffer and data page.
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<Product2<Object, Object>>();
+    for (int i = 0; i < UnsafeShuffleWriter.INITIAL_SORT_BUFFER_SIZE; i++) {
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
+    }
+    writer.write(dataToWrite.iterator());
+    verify(shuffleMemoryManager, times(5)).tryToAcquire(anyLong());
+    assertEquals(2, spillFilesCreated.size());
+    writer.stop(true);
+    readRecordsFromFile();
+    assertSpillFilesWereCleanedUp();
+    ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get();
+    assertEquals(dataToWrite.size(), shuffleWriteMetrics.shuffleRecordsWritten());
+    assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L));
+    assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length()));
+    assertThat(taskMetrics.memoryBytesSpilled(), greaterThan(0L));
+    assertEquals(mergedOutputFile.length(), shuffleWriteMetrics.shuffleBytesWritten());
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    final ArrayList<Product2<Object, Object>> dataToWrite =
+      new ArrayList<Product2<Object, Object>>();
+    final byte[] bytes = new byte[(int) (UnsafeShuffleExternalSorter.DISK_WRITE_BUFFER_SIZE * 2.5)];
+    new Random(42).nextBytes(bytes);
+    dataToWrite.add(new Tuple2<Object, Object>(1, ByteBuffer.wrap(bytes)));
+    writer.write(dataToWrite.iterator());
+    writer.stop(true);
+    assertEquals(
+      HashMultiset.create(dataToWrite),
+      HashMultiset.create(readRecordsFromFile()));
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    // Use a custom serializer so that we have exact control over the size of serialized data.
+    final Serializer byteArraySerializer = new Serializer() {
+      @Override
+      public SerializerInstance newInstance() {
+        return new SerializerInstance() {
+          @Override
+          public SerializationStream serializeStream(final OutputStream s) {
+            return new SerializationStream() {
+              @Override
+              public void flush() { }
+
+              @Override
+              public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+                byte[] bytes = (byte[]) t;
+                try {
+                  s.write(bytes);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+                return this;
+              }
+
+              @Override
+              public void close() { }
+            };
+          }
+          public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+          public DeserializationStream deserializeStream(InputStream s) { return null; }
+          public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+          public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+        };
+      }
+    };
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    // Insert a record and force a spill so that there's something to clean up:
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
+    writer.forceSorterToSpill();
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+    new Random(42).nextBytes(atMaxRecordSize);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+    writer.forceSorterToSpill();
+    // Inserting a record that's larger than the max record size should fail:
+    final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    Product2<Object, Object> hugeRecord =
+      new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
+    try {
+      // Here, we write through the public `write()` interface instead of the test-only
+      // `insertRecordIntoSorter` interface:
+      writer.write(Collections.singletonList(hugeRecord).iterator());
+      fail("Expected exception to be thrown");
+    } catch (IOException e) {
+      // Pass
+    }
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void spillFilesAreDeletedWhenStoppingAfterError() throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.stop(false);
+    assertSpillFilesWereCleanedUp();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index 8c6035f..cf6a143 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.io
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
 
+import com.google.common.io.ByteStreams
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkConf
@@ -62,6 +63,14 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("lz4 does not support concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName)
+    assert(codec.getClass === classOf[LZ4CompressionCodec])
+    intercept[Exception] {
+      testConcatenationOfSerializedStreams(codec)
+    }
+  }
+
   test("lzf compression codec") {
     val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
     assert(codec.getClass === classOf[LZFCompressionCodec])
@@ -74,6 +83,12 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("lzf supports concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName)
+    assert(codec.getClass === classOf[LZFCompressionCodec])
+    testConcatenationOfSerializedStreams(codec)
+  }
+
   test("snappy compression codec") {
     val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
     assert(codec.getClass === classOf[SnappyCompressionCodec])
@@ -86,9 +101,38 @@ class CompressionCodecSuite extends FunSuite {
     testCodec(codec)
   }
 
+  test("snappy does not support concatenation of serialized streams") {
+    val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
+    assert(codec.getClass === classOf[SnappyCompressionCodec])
+    intercept[Exception] {
+      testConcatenationOfSerializedStreams(codec)
+    }
+  }
+
   test("bad compression codec") {
     intercept[IllegalArgumentException] {
       CompressionCodec.createCodec(conf, "foobar")
     }
   }
+
+  private def testConcatenationOfSerializedStreams(codec: CompressionCodec): Unit = {
+    val bytes1: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val out = codec.compressedOutputStream(baos)
+      (0 to 64).foreach(out.write)
+      out.close()
+      baos.toByteArray
+    }
+    val bytes2: Array[Byte] = {
+      val baos = new ByteArrayOutputStream()
+      val out = codec.compressedOutputStream(baos)
+      (65 to 127).foreach(out.write)
+      out.close()
+      baos.toByteArray
+    }
+    val concatenatedBytes = codec.compressedInputStream(new ByteArrayInputStream(bytes1 ++ bytes2))
+    val decompressed: Array[Byte] = new Array[Byte](128)
+    ByteStreams.readFully(concatenatedBytes, decompressed)
+    assert(decompressed.toSeq === (0 to 127))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
new file mode 100644
index 0000000..ed4d8ce
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import org.apache.spark.SparkConf
+import org.scalatest.FunSuite
+
+class JavaSerializerSuite extends FunSuite {
+  test("JavaSerializer instances are serializable") {
+    val serializer = new JavaSerializer(new SparkConf())
+    val instance = serializer.newInstance()
+    instance.deserialize[JavaSerializer](instance.serialize(serializer))
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
new file mode 100644
index 0000000..49a04a2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManagerSuite.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{FunSuite, Matchers}
+
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}
+
+/**
+ * Tests for the fallback logic in UnsafeShuffleManager. Actual tests of shuffling data are
+ * performed in other suites.
+ */
+class UnsafeShuffleManagerSuite extends FunSuite with Matchers {
+
+  import UnsafeShuffleManager.canUseUnsafeShuffle
+
+  private class RuntimeExceptionAnswer extends Answer[Object] {
+    override def answer(invocation: InvocationOnMock): Object = {
+      throw new RuntimeException("Called non-stubbed method, " + invocation.getMethod.getName)
+    }
+  }
+
+  private def shuffleDep(
+      partitioner: Partitioner,
+      serializer: Option[Serializer],
+      keyOrdering: Option[Ordering[Any]],
+      aggregator: Option[Aggregator[Any, Any, Any]],
+      mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
+    val dep = mock(classOf[ShuffleDependency[Any, Any, Any]], new RuntimeExceptionAnswer())
+    doReturn(0).when(dep).shuffleId
+    doReturn(partitioner).when(dep).partitioner
+    doReturn(serializer).when(dep).serializer
+    doReturn(keyOrdering).when(dep).keyOrdering
+    doReturn(aggregator).when(dep).aggregator
+    doReturn(mapSideCombine).when(dep).mapSideCombine
+    dep
+  }
+
+  test("supported shuffle dependencies") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+
+    assert(canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    val rangePartitioner = mock(classOf[RangePartitioner[Any, Any]])
+    when(rangePartitioner.numPartitions).thenReturn(2)
+    assert(canUseUnsafeShuffle(shuffleDep(
+      partitioner = rangePartitioner,
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+  }
+
+  test("unsupported shuffle dependencies") {
+    val kryo = Some(new KryoSerializer(new SparkConf()))
+    val java = Some(new JavaSerializer(new SparkConf()))
+
+    // We only support serializers that support object relocation
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = java,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles with more than 16 million output partitions
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS + 1),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = None,
+      mapSideCombine = false
+    )))
+
+    // We do not support shuffles that perform any kind of aggregation or sorting of keys
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = None,
+      mapSideCombine = false
+    )))
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = None,
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = false
+    )))
+    // Nor do we support ordering, aggregation, and map-side combine all requested together
+    assert(!canUseUnsafeShuffle(shuffleDep(
+      partitioner = new HashPartitioner(2),
+      serializer = kryo,
+      keyOrdering = Some(mock(classOf[Ordering[Any]])),
+      aggregator = Some(mock(classOf[Aggregator[Any, Any, Any]])),
+      mapSideCombine = true
+    )))
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
new file mode 100644
index 0000000..6351539
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleSuite.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe
+
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.{HashPartitioner, ShuffleDependency, SparkContext, ShuffleSuite}
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.util.Utils
+
+class UnsafeShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+  // Runs every test inherited from ShuffleSuite using the tungsten-sort shuffle manager.
+
+  override def beforeAll(): Unit = {
+    conf.set("spark.shuffle.manager", "tungsten-sort")
+    // UnsafeShuffleManager requires at least 128 MB of memory per task in order to be able to sort
+    // shuffle records.
+    conf.set("spark.shuffle.memoryFraction", "0.5")
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the new shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the new UnsafeShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new KryoSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle created exactly the files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+
+  test("UnsafeShuffleManager properly cleans up files for shuffles that use the old shuffle path") {
+    val tmpDir = Utils.createTempDir()
+    try {
+      val myConf = conf.clone()
+        .set("spark.local.dir", tmpDir.getAbsolutePath)
+      sc = new SparkContext("local", "test", myConf)
+      // Create a shuffled RDD and verify that it will actually use the old SortShuffle path
+      val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+      val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+        .setSerializer(new JavaSerializer(myConf))
+      val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
+      assert(!UnsafeShuffleManager.canUseUnsafeShuffle(shuffleDep))
+      def getAllFiles: Set[File] =
+        FileUtils.listFiles(tmpDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+      val filesBeforeShuffle = getAllFiles
+      // Force the shuffle to be performed
+      shuffledRdd.count()
+      // Ensure that the shuffle created exactly the files that will need to be cleaned up
+      val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
+      filesCreatedByShuffle.map(_.getName) should be (
+        Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+      // Check that the cleanup actually removes the files
+      sc.env.blockManager.master.removeShuffle(0, blocking = true)
+      for (file <- filesCreatedByShuffle) {
+        assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+      }
+    } finally {
+      Utils.deleteRecursively(tmpDir)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/pom.xml
----------------------------------------------------------------------
diff --git a/pom.xml b/pom.xml
index cf9279e..564a443 100644
--- a/pom.xml
+++ b/pom.xml
@@ -669,7 +669,7 @@
       <dependency>
         <groupId>org.mockito</groupId>
         <artifactId>mockito-all</artifactId>
-        <version>1.9.0</version>
+        <version>1.9.5</version>
         <scope>test</scope>
       </dependency>
       <dependency>
@@ -685,6 +685,18 @@
         <scope>test</scope>
       </dependency>
       <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-core</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
+        <groupId>org.hamcrest</groupId>
+        <artifactId>hamcrest-library</artifactId>
+        <version>1.3</version>
+        <scope>test</scope>
+      </dependency>
+      <dependency>
         <groupId>com.novocode</groupId>
         <artifactId>junit-interface</artifactId>
         <version>0.10</version>

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fba7290..487062a 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -131,6 +131,12 @@ object MimaExcludes {
             // SPARK-7530 Added StreamingContext.getState()
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.streaming.StreamingContext.state_=")
+          ) ++ Seq(
+            // SPARK-7081 changed ShuffleWriter from a trait to an abstract class and removed some
+            // unnecessary type bounds in order to fix some compiler warnings that occurred when
+            // implementing this interface in Java. Note that ShuffleWriter is private[spark].
+            ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+              "org.apache.spark.shuffle.ShuffleWriter")
           )
 
         case v if v.startsWith("1.3") =>

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index c3d2c70..3e46596 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -17,17 +17,18 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.shuffle.sort.SortShuffleManager
 import org.apache.spark.{HashPartitioner, Partitioner, RangePartitioner, SparkEnv}
+import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.shuffle.unsafe.UnsafeShuffleManager
 import org.apache.spark.sql.catalyst.errors.attachTree
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.util.MutablePair
 
 object Exchange {
@@ -85,7 +86,9 @@ case class Exchange(
     // corner-cases where a partitioner constructed with `numPartitions` partitions may output
     // fewer partitions (like RangePartitioner, for example).
     val conf = child.sqlContext.sparkContext.conf
-    val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+    val shuffleManager = SparkEnv.get.shuffleManager
+    val sortBasedShuffleOn = shuffleManager.isInstanceOf[SortShuffleManager] ||
+      shuffleManager.isInstanceOf[UnsafeShuffleManager]
     val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
     val serializeMapOutputs = conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
     if (newOrdering.nonEmpty) {
@@ -93,11 +96,11 @@ case class Exchange(
       // which requires a defensive copy.
       true
     } else if (sortBasedShuffleOn) {
-      // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
-      // However, there are two special cases where we can avoid the copy, described below:
-      if (partitioner.numPartitions <= bypassMergeThreshold) {
-        // If the number of output partitions is sufficiently small, then Spark will fall back to
-        // the old hash-based shuffle write path which doesn't buffer deserialized records.
+      val bypassIsSupported = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+      if (bypassIsSupported && partitioner.numPartitions <= bypassMergeThreshold) {
+        // If we're using the original SortShuffleManager and the number of output partitions is
+        // sufficiently small, then Spark will fall back to the hash-based shuffle write path, which
+        // doesn't buffer deserialized records.
         // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
         false
       } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
@@ -105,9 +108,14 @@ case class Exchange(
         // them. This optimization is guarded by a feature-flag and is only applied in cases where
         // shuffle dependency does not specify an ordering and the record serializer has certain
         // properties. If this optimization is enabled, we can safely avoid the copy.
+        //
+        // This optimization also applies to UnsafeShuffleManager (added in SPARK-7081).
         false
       } else {
-        // None of the special cases held, so we must copy.
+        // Spark's SortShuffleManager uses `ExternalSorter` to buffer records in memory. This code
+        // path is used both when SortShuffleManager is used and when UnsafeShuffleManager falls
+        // back to SortShuffleManager to perform a shuffle that the new fast path can't handle. In
+        // both cases, we must copy.
         true
       }
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/unsafe/pom.xml
----------------------------------------------------------------------
diff --git a/unsafe/pom.xml b/unsafe/pom.xml
index 5b07332..9e151fc 100644
--- a/unsafe/pom.xml
+++ b/unsafe/pom.xml
@@ -42,6 +42,10 @@
       <groupId>com.google.code.findbugs</groupId>
       <artifactId>jsr305</artifactId>
     </dependency>
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+    </dependency>
 
     <!-- Provided dependencies -->
     <dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
----------------------------------------------------------------------
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
index 9224988..2906ac8 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/memory/TaskMemoryManager.java
@@ -19,6 +19,7 @@ package org.apache.spark.unsafe.memory;
 
 import java.util.*;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -47,10 +48,18 @@ public final class TaskMemoryManager {
 
   private final Logger logger = LoggerFactory.getLogger(TaskMemoryManager.class);
 
-  /**
-   * The number of entries in the page table.
-   */
-  private static final int PAGE_TABLE_SIZE = 1 << 13;
+  /** The number of bits used to address the page table. */
+  private static final int PAGE_NUMBER_BITS = 13;
+
+  /** The number of bits used to encode offsets in data pages. */
+  @VisibleForTesting
+  static final int OFFSET_BITS = 64 - PAGE_NUMBER_BITS;  // 51
+
+  /** The number of entries in the page table. */
+  private static final int PAGE_TABLE_SIZE = 1 << PAGE_NUMBER_BITS;
+
+  /** Maximum supported data page size */
+  private static final long MAXIMUM_PAGE_SIZE = (1L << OFFSET_BITS);
 
   /** Bit mask for the lower 51 bits of a long. */
   private static final long MASK_LONG_LOWER_51_BITS = 0x7FFFFFFFFFFFFL;
@@ -101,11 +110,9 @@ public final class TaskMemoryManager {
    * intended for allocating large blocks of memory that will be shared between operators.
    */
   public MemoryBlock allocatePage(long size) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Allocating {} byte page", size);
-    }
-    if (size >= (1L << 51)) {
-      throw new IllegalArgumentException("Cannot allocate a page with more than 2^51 bytes");
+    if (size > MAXIMUM_PAGE_SIZE) {
+      throw new IllegalArgumentException(
+        "Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE + " bytes");
     }
 
     final int pageNumber;
@@ -120,8 +127,8 @@ public final class TaskMemoryManager {
     final MemoryBlock page = executorMemoryManager.allocate(size);
     page.pageNumber = pageNumber;
     pageTable[pageNumber] = page;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Allocate page number {} ({} bytes)", pageNumber, size);
+    if (logger.isTraceEnabled()) {
+      logger.trace("Allocate page number {} ({} bytes)", pageNumber, size);
     }
     return page;
   }
@@ -130,9 +137,6 @@ public final class TaskMemoryManager {
    * Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
    */
   public void freePage(MemoryBlock page) {
-    if (logger.isTraceEnabled()) {
-      logger.trace("Freeing page number {} ({} bytes)", page.pageNumber, page.size());
-    }
     assert (page.pageNumber != -1) :
       "Called freePage() on memory that wasn't allocated with allocatePage()";
     executorMemoryManager.free(page);
@@ -140,8 +144,8 @@ public final class TaskMemoryManager {
       allocatedPages.clear(page.pageNumber);
     }
     pageTable[page.pageNumber] = null;
-    if (logger.isDebugEnabled()) {
-      logger.debug("Freed page number {} ({} bytes)", page.pageNumber, page.size());
+    if (logger.isTraceEnabled()) {
+      logger.trace("Freed page number {} ({} bytes)", page.pageNumber, page.size());
     }
   }
 
@@ -173,14 +177,36 @@ public final class TaskMemoryManager {
   /**
    * Given a memory page and offset within that page, encode this address into a 64-bit long.
    * This address will remain valid as long as the corresponding page has not been freed.
+   *
+   * @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}.
+   * @param offsetInPage an offset in this page which incorporates the base offset. In other words,
+   *                     this should be the value that you would pass as the base offset into an
+   *                     UNSAFE call (e.g. page.baseOffset() + something).
+   * @return an encoded page address.
    */
   public long encodePageNumberAndOffset(MemoryBlock page, long offsetInPage) {
-    if (inHeap) {
-      assert (page.pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
-      return (((long) page.pageNumber) << 51) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
-    } else {
-      return offsetInPage;
+    if (!inHeap) {
+      // In off-heap mode, an offset is an absolute address that may require a full 64 bits to
+      // encode. Due to our page size limitation, though, we can convert this into an offset that's
+      // relative to the page's base offset; this relative offset will fit in 51 bits.
+      offsetInPage -= page.getBaseOffset();
     }
+    return encodePageNumberAndOffset(page.pageNumber, offsetInPage);
+  }
+
+  @VisibleForTesting
+  public static long encodePageNumberAndOffset(int pageNumber, long offsetInPage) {
+    assert (pageNumber != -1) : "encodePageNumberAndOffset called with invalid page";
+    return (((long) pageNumber) << OFFSET_BITS) | (offsetInPage & MASK_LONG_LOWER_51_BITS);
+  }
+
+  @VisibleForTesting
+  public static int decodePageNumber(long pagePlusOffsetAddress) {
+    return (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> OFFSET_BITS);
+  }
+
+  private static long decodeOffset(long pagePlusOffsetAddress) {
+    return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
   }
 
   /**
@@ -189,7 +215,7 @@ public final class TaskMemoryManager {
    */
   public Object getPage(long pagePlusOffsetAddress) {
     if (inHeap) {
-      final int pageNumber = (int) ((pagePlusOffsetAddress & MASK_LONG_UPPER_13_BITS) >>> 51);
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
       assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
       final Object page = pageTable[pageNumber].getBaseObject();
       assert (page != null);
@@ -204,10 +230,15 @@ public final class TaskMemoryManager {
    * {@link TaskMemoryManager#encodePageNumberAndOffset(MemoryBlock, long)}
    */
   public long getOffsetInPage(long pagePlusOffsetAddress) {
+    final long offsetInPage = decodeOffset(pagePlusOffsetAddress);
     if (inHeap) {
-      return (pagePlusOffsetAddress & MASK_LONG_LOWER_51_BITS);
+      return offsetInPage;
     } else {
-      return pagePlusOffsetAddress;
+      // In off-heap mode, an offset is an absolute address. In encodePageNumberAndOffset, we
+      // converted the absolute address into a relative address. Here, we invert that operation:
+      final int pageNumber = decodePageNumber(pagePlusOffsetAddress);
+      assert (pageNumber >= 0 && pageNumber < PAGE_TABLE_SIZE);
+      return pageTable[pageNumber].getBaseOffset() + offsetInPage;
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
----------------------------------------------------------------------
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
index 932882f..06fb081 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/memory/TaskMemoryManagerSuite.java
@@ -38,4 +38,27 @@ public class TaskMemoryManagerSuite {
     Assert.assertEquals(4096, manager.cleanUpAllAllocatedMemory());
   }
 
+  @Test
+  public void encodePageNumberAndOffsetOffHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.UNSAFE));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    // In off-heap mode, an offset is an absolute address that may require more than 51 bits to
+    // encode. This test exercises that corner-case:
+    final long offset = ((1L << TaskMemoryManager.OFFSET_BITS) + 10);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, offset);
+    Assert.assertEquals(null, manager.getPage(encodedAddress));
+    Assert.assertEquals(offset, manager.getOffsetInPage(encodedAddress));
+  }
+
+  @Test
+  public void encodePageNumberAndOffsetOnHeap() {
+    final TaskMemoryManager manager =
+      new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+    final MemoryBlock dataPage = manager.allocatePage(256);
+    final long encodedAddress = manager.encodePageNumberAndOffset(dataPage, 64);
+    Assert.assertEquals(dataPage.getBaseObject(), manager.getPage(encodedAddress));
+    Assert.assertEquals(64, manager.getOffsetInPage(encodedAddress));
+  }
+
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


[2/2] spark git commit: [SPARK-7081] Faster sort-based shuffle path using binary processing cache-aware sort

Posted by rx...@apache.org.
[SPARK-7081] Faster sort-based shuffle path using binary processing cache-aware sort

This patch introduces a new shuffle manager that enhances the existing sort-based shuffle with a new cache-friendly sort algorithm that operates directly on binary data. The goals of this patch are to lower memory usage and Java object overheads during shuffle and to speed up sorting. It also lays groundwork for follow-up patches that will enable end-to-end processing of serialized records.

The new shuffle manager, `UnsafeShuffleManager`, can be enabled by setting `spark.shuffle.manager=tungsten-sort` in SparkConf.

The new shuffle manager uses directly-managed memory to implement several performance optimizations for certain types of shuffles. In cases where the new performance optimizations cannot be applied, the new shuffle manager delegates to SortShuffleManager to handle those shuffles.

UnsafeShuffleManager's optimizations will apply when _all_ of the following conditions hold:

 - The shuffle dependency specifies no aggregation or output ordering.
 - The shuffle serializer supports relocation of serialized values (this is currently supported
   by KryoSerializer and Spark SQL's custom serializers).
 - The shuffle produces fewer than 16777216 output partitions.
 - No individual record is larger than 128 MB when serialized.

In addition, extra spill-merging optimizations are automatically applied when the shuffle compression codec supports concatenation of serialized streams. This is currently supported by Spark's LZF compression codec.

At a high level, UnsafeShuffleManager's design is similar to Spark's existing SortShuffleManager.  In sort-based shuffle, incoming records are sorted according to their target partition ids, then written to a single map output file. Reducers fetch contiguous regions of this file in order to read their portion of the map output. In cases where the map output data is too large to fit in memory, sorted subsets of the output are spilled to disk and those on-disk files are merged to produce the final output file.

UnsafeShuffleManager optimizes this process in several ways:

 - Its sort operates on serialized binary data rather than Java objects, which reduces memory consumption and GC overheads. This optimization requires the record serializer to have certain properties to allow serialized records to be re-ordered without requiring deserialization.  See SPARK-4550, where this optimization was first proposed and implemented, for more details.

 - It uses a specialized cache-efficient sorter (UnsafeShuffleExternalSorter) that sorts arrays of compressed record pointers and partition ids. By using only 8 bytes of space per record in the sorting array, this fits more of the array into cache.

 - The spill merging procedure operates on blocks of serialized records that belong to the same partition and does not need to deserialize records during the merge.

 - When the spill compression codec supports concatenation of compressed data, the spill merge simply concatenates the serialized and compressed spill partitions to produce the final output partition.  This allows efficient data copying methods, like NIO's `transferTo`, to be used and avoids the need to allocate decompression or copying buffers during the merge.

The shuffle read path is unchanged.

This patch is similar to [SPARK-4550](http://issues.apache.org/jira/browse/SPARK-4550) / #4450 but uses a slightly different implementation. The `unsafe`-based implementation featured in this patch lays the groundwork for followup patches that will enable sorting to operate on serialized data pages that will be prepared by Spark SQL's new `unsafe` operators (such as the new aggregation operator introduced in #5725).

### Future work

There are several tasks that build upon this patch, which will be left to future work:

- [SPARK-7271](https://issues.apache.org/jira/browse/SPARK-7271) Redesign / extend the shuffle interfaces to accept binary data as input. The goal here is to let us bypass serialization steps in cases where the sort input is produced by an operator that operates directly on binary data.
- Extension / redesign of the `Serializer` API. We can add new methods which allow serializers to determine the size requirements for serializing objects and for serializing objects directly to a specified memory address (similar to how `UnsafeRowConverter` works in Spark SQL).

<!-- Reviewable:start -->
[<img src="https://reviewable.io/review_button.png" height=40 alt="Review on Reviewable"/>](https://reviewable.io/reviews/apache/spark/5868)
<!-- Reviewable:end -->

Author: Josh Rosen <jo...@databricks.com>

Closes #5868 from JoshRosen/unsafe-sort and squashes the following commits:

ef0a86e [Josh Rosen] Fix scalastyle errors
7610f2f [Josh Rosen] Add tests for proper cleanup of shuffle data.
d494ffe [Josh Rosen] Fix deserialization of JavaSerializer instances.
52a9981 [Josh Rosen] Fix some bugs in the address packing code.
51812a7 [Josh Rosen] Change shuffle manager sort name to tungsten-sort
4023fa4 [Josh Rosen] Add @Private annotation to some Java classes.
de40b9d [Josh Rosen] More comments to try to explain metrics code
df07699 [Josh Rosen] Attempt to clarify confusing metrics update code
5e189c6 [Josh Rosen] Track time spend closing / flushing files; split TimeTrackingOutputStream into separate file.
d5779c6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort
c2ce78e [Josh Rosen] Fix a missed usage of MAX_PARTITION_ID
e3b8855 [Josh Rosen] Cleanup in UnsafeShuffleWriter
4a2c785 [Josh Rosen] rename 'sort buffer' to 'pointer array'
6276168 [Josh Rosen] Remove ability to disable spilling in UnsafeShuffleExternalSorter.
57312c9 [Josh Rosen] Clarify fileBufferSize units
2d4e4f4 [Josh Rosen] Address some minor comments in UnsafeShuffleExternalSorter.
fdcac08 [Josh Rosen] Guard against overflow when expanding sort buffer.
85da63f [Josh Rosen] Cleanup in UnsafeShuffleSorterIterator.
0ad34da [Josh Rosen] Fix off-by-one in nextInt() call
56781a1 [Josh Rosen] Rename UnsafeShuffleSorter to UnsafeShuffleInMemorySorter
e995d1a [Josh Rosen] Introduce MAX_SHUFFLE_OUTPUT_PARTITIONS.
e58a6b4 [Josh Rosen] Add more tests for PackedRecordPointer encoding.
4f0b770 [Josh Rosen] Attempt to implement proper shuffle write metrics.
d4e6d89 [Josh Rosen] Update to bit shifting constants
69d5899 [Josh Rosen] Remove some unnecessary override vals
8531286 [Josh Rosen] Add tests that automatically trigger spills.
7c953f9 [Josh Rosen] Add test that covers UnsafeShuffleSortDataFormat.swap().
e1855e5 [Josh Rosen] Fix a handful of misc. IntelliJ inspections
39434f9 [Josh Rosen] Avoid integer multiplication overflow in getMemoryUsage (thanks FindBugs!)
1e3ad52 [Josh Rosen] Delete unused ByteBufferOutputStream class.
ea4f85f [Josh Rosen] Roll back an unnecessary change in Spillable.
ae538dc [Josh Rosen] Document UnsafeShuffleManager.
ec6d626 [Josh Rosen] Add notes on maximum # of supported shuffle partitions.
0d4d199 [Josh Rosen] Bump up shuffle.memoryFraction to make tests pass.
b3b1924 [Josh Rosen] Properly implement close() and flush() in DummySerializerInstance.
1ef56c7 [Josh Rosen] Revise compression codec support in merger; test cross product of configurations.
b57c17f [Josh Rosen] Disable some overly-verbose logs that rendered DEBUG useless.
f780fb1 [Josh Rosen] Add test demonstrating which compression codecs support concatenation.
4a01c45 [Josh Rosen] Remove unnecessary log message
27b18b0 [Josh Rosen] That for inserting records AT the max record size.
fcd9a3c [Josh Rosen] Add notes + tests for maximum record / page sizes.
9d1ee7c [Josh Rosen] Fix MiMa excludes for ShuffleWriter change
fd4bb9e [Josh Rosen] Use own ByteBufferOutputStream rather than Kryo's
67d25ba [Josh Rosen] Update Exchange operator's copying logic to account for new shuffle manager
8f5061a [Josh Rosen] Strengthen assertion to check partitioning
01afc74 [Josh Rosen] Actually read data in UnsafeShuffleWriterSuite
1929a74 [Josh Rosen] Update to reflect upstream ShuffleBlockManager -> ShuffleBlockResolver rename.
e8718dd [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort
9b7ebed [Josh Rosen] More defensive programming RE: cleaning up spill files and memory after errors
7cd013b [Josh Rosen] Begin refactoring to enable proper tests for spilling.
722849b [Josh Rosen] Add workaround for transferTo() bug in merging code; refactor tests.
9883e30 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort
b95e642 [Josh Rosen] Refactor and document logic that decides when to spill.
1ce1300 [Josh Rosen] More minor cleanup
5e8cf75 [Josh Rosen] More minor cleanup
e67f1ea [Josh Rosen] Remove upper type bound in ShuffleWriter interface.
cfe0ec4 [Josh Rosen] Address a number of minor review comments:
8a6fe52 [Josh Rosen] Rename UnsafeShuffleSpillWriter to UnsafeShuffleExternalSorter
11feeb6 [Josh Rosen] Update TODOs related to shuffle write metrics.
b674412 [Josh Rosen] Merge remote-tracking branch 'origin/master' into unsafe-sort
aaea17b [Josh Rosen] Add comments to UnsafeShuffleSpillWriter.
4f70141 [Josh Rosen] Fix merging; now passes UnsafeShuffleSuite tests.
133c8c9 [Josh Rosen] WIP towards testing UnsafeShuffleWriter.
f480fb2 [Josh Rosen] WIP in mega-refactoring towards shuffle-specific sort.
57f1ec0 [Josh Rosen] WIP towards packed record pointers for use in optimized shuffle sort.
69232fd [Josh Rosen] Enable compressible address encoding for off-heap mode.
7ee918e [Josh Rosen] Re-order imports in tests
3aeaff7 [Josh Rosen] More refactoring and cleanup; begin cleaning iterator interfaces
3490512 [Josh Rosen] Misc. cleanup
f156a8f [Josh Rosen] Hacky metrics integration; refactor some interfaces.
2776aca [Josh Rosen] First passing test for ExternalSorter.
5e100b2 [Josh Rosen] Super-messy WIP on external sort
595923a [Josh Rosen] Remove some unused variables.
8958584 [Josh Rosen] Fix bug in calculating free space in current page.
f17fa8f [Josh Rosen] Add missing newline
c2fca17 [Josh Rosen] Small refactoring of SerializerPropertiesSuite to enable test re-use:
b8a09fe [Josh Rosen] Back out accidental log4j.properties change
bfc12d3 [Josh Rosen] Add tests for serializer relocation property.
240864c [Josh Rosen] Remove PrefixComputer and require prefix to be specified as part of insert()
1433b42 [Josh Rosen] Store record length as int instead of long.
026b497 [Josh Rosen] Re-use a buffer in UnsafeShuffleWriter
0748458 [Josh Rosen] Port UnsafeShuffleWriter to Java.
87e721b [Josh Rosen] Renaming and comments
d3cc310 [Josh Rosen] Flag that SparkSqlSerializer2 supports relocation
e2d96ca [Josh Rosen] Expand serializer API and use new function to help control when new UnsafeShuffle path is used.
e267cee [Josh Rosen] Fix compilation of UnsafeSorterSuite
9c6cf58 [Josh Rosen] Refactor to use DiskBlockObjectWriter.
253f13e [Josh Rosen] More cleanup
8e3ec20 [Josh Rosen] Begin code cleanup.
4d2f5e1 [Josh Rosen] WIP
3db12de [Josh Rosen] Minor simplification and sanity checks in UnsafeSorter
767d3ca [Josh Rosen] Fix invalid range in UnsafeSorter.
e900152 [Josh Rosen] Add test for empty iterator in UnsafeSorter
57a4ea0 [Josh Rosen] Make initialSize configurable in UnsafeSorter
abf7bfe [Josh Rosen] Add basic test case.
81d52c5 [Josh Rosen] WIP on UnsafeSorter


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/73bed408
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/73bed408
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/73bed408

Branch: refs/heads/master
Commit: 73bed408fbb47dfc28063afa3898c27fbdec7735
Parents: 61d1e87
Author: Josh Rosen <jo...@databricks.com>
Authored: Wed May 13 17:07:31 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Wed May 13 17:07:31 2015 -0700

----------------------------------------------------------------------
 core/pom.xml                                    |  10 +
 .../shuffle/unsafe/DummySerializerInstance.java |  93 ++++
 .../shuffle/unsafe/PackedRecordPointer.java     |  92 ++++
 .../apache/spark/shuffle/unsafe/SpillInfo.java  |  37 ++
 .../unsafe/UnsafeShuffleExternalSorter.java     | 422 +++++++++++++++
 .../unsafe/UnsafeShuffleInMemorySorter.java     | 124 +++++
 .../unsafe/UnsafeShuffleSortDataFormat.java     |  67 +++
 .../shuffle/unsafe/UnsafeShuffleWriter.java     | 438 +++++++++++++++
 .../spark/storage/TimeTrackingOutputStream.java |  75 +++
 .../main/scala/org/apache/spark/SparkEnv.scala  |   3 +-
 .../spark/serializer/JavaSerializer.scala       |   2 +
 .../shuffle/FileShuffleBlockResolver.scala      |   2 +-
 .../apache/spark/shuffle/ShuffleWriter.scala    |   7 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   2 +-
 .../spark/shuffle/sort/SortShuffleManager.scala |   2 +-
 .../spark/shuffle/sort/SortShuffleWriter.scala  |   2 +-
 .../shuffle/unsafe/UnsafeShuffleManager.scala   | 205 ++++++++
 .../spark/storage/BlockObjectWriter.scala       |  24 +-
 .../util/collection/ExternalAppendOnlyMap.scala |   2 +-
 .../spark/util/collection/ExternalSorter.scala  |   2 +-
 .../unsafe/PackedRecordPointerSuite.java        | 101 ++++
 .../UnsafeShuffleInMemorySorterSuite.java       | 132 +++++
 .../unsafe/UnsafeShuffleWriterSuite.java        | 527 +++++++++++++++++++
 .../apache/spark/io/CompressionCodecSuite.scala |  44 ++
 .../spark/serializer/JavaSerializerSuite.scala  |  29 +
 .../unsafe/UnsafeShuffleManagerSuite.scala      | 128 +++++
 .../shuffle/unsafe/UnsafeShuffleSuite.scala     | 105 ++++
 pom.xml                                         |  14 +-
 project/MimaExcludes.scala                      |   6 +
 .../apache/spark/sql/execution/Exchange.scala   |  28 +-
 unsafe/pom.xml                                  |   4 +
 .../spark/unsafe/memory/TaskMemoryManager.java  |  79 ++-
 .../unsafe/memory/TaskMemoryManagerSuite.java   |  23 +
 33 files changed, 2767 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/pom.xml
----------------------------------------------------------------------
diff --git a/core/pom.xml b/core/pom.xml
index 262a332..bfa49d0 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -362,6 +362,16 @@
       <scope>test</scope>
     </dependency>
     <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-core</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.hamcrest</groupId>
+      <artifactId>hamcrest-library</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
       <groupId>com.novocode</groupId>
       <artifactId>junit-interface</artifactId>
       <scope>test</scope>

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
new file mode 100644
index 0000000..3f746b8
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/DummySerializerInstance.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+
+import scala.reflect.ClassTag;
+
+import org.apache.spark.serializer.DeserializationStream;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.unsafe.PlatformDependent;
+
+/**
+ * Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+ * Our shuffle write path doesn't actually use this serializer (since we end up calling the
+ * `write()` OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+ * around this, we pass a dummy no-op serializer.
+ */
+final class DummySerializerInstance extends SerializerInstance {
+
+  public static final DummySerializerInstance INSTANCE = new DummySerializerInstance();
+
+  private DummySerializerInstance() { }
+
+  @Override
+  public SerializationStream serializeStream(final OutputStream s) {
+    return new SerializationStream() {
+      @Override
+      public void flush() {
+        // Need to implement this because DiskObjectWriter uses it to flush the compression stream
+        try {
+          s.flush();
+        } catch (IOException e) {
+          PlatformDependent.throwException(e);
+        }
+      }
+
+      @Override
+      public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+        throw new UnsupportedOperationException();
+      }
+
+      @Override
+      public void close() {
+        // Need to implement this because DiskObjectWriter uses it to close the compression stream
+        try {
+          s.close();
+        } catch (IOException e) {
+          PlatformDependent.throwException(e);
+        }
+      }
+    };
+  }
+
+  @Override
+  public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public DeserializationStream deserializeStream(InputStream s) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public <T> T deserialize(ByteBuffer bytes, ClassLoader loader, ClassTag<T> ev1) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) {
+    throw new UnsupportedOperationException();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
new file mode 100644
index 0000000..4ee6a82
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/PackedRecordPointer.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+/**
+ * Wrapper around an 8-byte word that holds a 24-bit partition number and 40-bit record pointer.
+ * <p>
+ * Within the long, the data is laid out as follows:
+ * <pre>
+ *   [24 bit partition number][13 bit memory page number][27 bit offset in page]
+ * </pre>
+ * This implies that the maximum addressable page size is 2^27 bytes = 128 megabytes, assuming that
+ * our offsets in pages are not 8-byte-word-aligned. Since we have 2^13 pages (based off the
+ * 13-bit page numbers assigned by {@link org.apache.spark.unsafe.memory.TaskMemoryManager}), this
+ * implies that we can address 2^13 * 128 megabytes = 1 terabyte of RAM per task.
+ * <p>
+ * Assuming word-alignment would allow for a 1 gigabyte maximum page size, but we leave this
+ * optimization to future work as it will require more careful design to ensure that addresses are
+ * properly aligned (e.g. by padding records).
+ */
+final class PackedRecordPointer {
+
+  static final int MAXIMUM_PAGE_SIZE_BYTES = 1 << 27;  // 128 megabytes
+
+  /**
+   * The maximum partition identifier that can be encoded. Note that partition ids start from 0.
+   */
+  static final int MAXIMUM_PARTITION_ID = (1 << 24) - 1;  // 16777215
+
+  /** Bit mask for the lower 40 bits of a long. */
+  private static final long MASK_LONG_LOWER_40_BITS = (1L << 40) - 1;
+
+  /** Bit mask for the upper 24 bits of a long */
+  private static final long MASK_LONG_UPPER_24_BITS = ~MASK_LONG_LOWER_40_BITS;
+
+  /** Bit mask for the lower 27 bits of a long. */
+  private static final long MASK_LONG_LOWER_27_BITS = (1L << 27) - 1;
+
+  /** Bit mask for the lower 51 bits of a long. */
+  private static final long MASK_LONG_LOWER_51_BITS = (1L << 51) - 1;
+
+  /** Bit mask for the upper 13 bits of a long */
+  private static final long MASK_LONG_UPPER_13_BITS = ~MASK_LONG_LOWER_51_BITS;
+
+  /**
+   * Pack a record address and partition id into a single word.
+   *
+   * @param recordPointer a record pointer encoded by TaskMemoryManager.
+   * @param partitionId a shuffle partition id (maximum value of 2^24 - 1).
+   * @return a packed pointer that can be decoded using the {@link PackedRecordPointer} class.
+   */
+  public static long packPointer(long recordPointer, int partitionId) {
+    assert (partitionId <= MAXIMUM_PARTITION_ID);
+    // Note that without word alignment we can address 2^27 bytes = 128 megabytes per page.
+    // Also note that this relies on some internals of how TaskMemoryManager encodes its addresses.
+    final long pageNumber = (recordPointer & MASK_LONG_UPPER_13_BITS) >>> 24;
+    final long compressedAddress = pageNumber | (recordPointer & MASK_LONG_LOWER_27_BITS);
+    return (((long) partitionId) << 40) | compressedAddress;
+  }
+
+  private long packedRecordPointer;
+
+  public void set(long packedRecordPointer) {
+    this.packedRecordPointer = packedRecordPointer;
+  }
+
+  public int getPartitionId() {
+    return (int) ((packedRecordPointer & MASK_LONG_UPPER_24_BITS) >>> 40);
+  }
+
+  public long getRecordPointer() {
+    final long pageNumber = (packedRecordPointer << 24) & MASK_LONG_UPPER_13_BITS;
+    final long offsetInPage = packedRecordPointer & MASK_LONG_LOWER_27_BITS;
+    return pageNumber | offsetInPage;
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
new file mode 100644
index 0000000..7bac0dc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/SpillInfo.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+
+import org.apache.spark.storage.TempShuffleBlockId;
+
+/**
+ * Metadata for a block of data written by {@link UnsafeShuffleExternalSorter}.
+ */
+final class SpillInfo {
+  final long[] partitionLengths;
+  final File file;
+  final TempShuffleBlockId blockId;
+
+  public SpillInfo(int numPartitions, File file, TempShuffleBlockId blockId) {
+    this.partitionLengths = new long[numPartitions];
+    this.file = file;
+    this.blockId = blockId;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
new file mode 100644
index 0000000..9e9ed94
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
@@ -0,0 +1,422 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.LinkedList;
+
+import scala.Tuple2;
+
+import com.google.common.annotations.VisibleForTesting;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+
+/**
+ * An external sorter that is specialized for sort-based shuffle.
+ * <p>
+ * Incoming records are appended to data pages. When all records have been inserted (or when the
+ * current thread's shuffle memory limit is reached), the in-memory records are sorted according to
+ * their partition ids (using a {@link UnsafeShuffleInMemorySorter}). The sorted records are then
+ * written to a single output file (or multiple files, if we've spilled). The format of the output
+ * files is the same as the format of the final output file written by
+ * {@link org.apache.spark.shuffle.sort.SortShuffleWriter}: each output partition's records are
+ * written as a single serialized, compressed stream that can be read with a new decompression and
+ * deserialization stream.
+ * <p>
+ * Unlike {@link org.apache.spark.util.collection.ExternalSorter}, this sorter does not merge its
+ * spill files. Instead, this merging is performed in {@link UnsafeShuffleWriter}, which uses a
+ * specialized merge procedure that avoids extra serialization/deserialization.
+ */
+final class UnsafeShuffleExternalSorter {
+
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);
+
+  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+  @VisibleForTesting
+  static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
+  @VisibleForTesting
+  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;
+
+  private final int initialSize;
+  private final int numPartitions;
+  private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final BlockManager blockManager;
+  private final TaskContext taskContext;
+  private final ShuffleWriteMetrics writeMetrics;
+
+  /** The buffer size to use when writing spills using DiskBlockObjectWriter */
+  private final int fileBufferSizeBytes;
+
+  /**
+   * Memory pages that hold the records being sorted. The pages in this list are freed when
+   * spilling, although in principle we could recycle these pages across spills (on the other hand,
+   * this might not be necessary if we maintained a pool of re-usable pages in the TaskMemoryManager
+   * itself).
+   */
+  private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
+
+  private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
+
+  // These variables are reset after spilling:
+  private UnsafeShuffleInMemorySorter sorter;
+  private MemoryBlock currentPage = null;
+  private long currentPagePosition = -1;
+  private long freeSpaceInCurrentPage = 0;
+
+  public UnsafeShuffleExternalSorter(
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf,
+      ShuffleWriteMetrics writeMetrics) throws IOException {
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.blockManager = blockManager;
+    this.taskContext = taskContext;
+    this.initialSize = initialSize;
+    this.numPartitions = numPartitions;
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+
+    this.writeMetrics = writeMetrics;
+    initializeForWriting();
+  }
+
+  /**
+   * Allocates new sort data structures. Called when creating the sorter and after each spill.
+   */
+  private void initializeForWriting() throws IOException {
+    // TODO: move this sizing calculation logic into a static method of sorter:
+    final long memoryRequested = initialSize * 8L;
+    final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
+    if (memoryAcquired != memoryRequested) {
+      shuffleMemoryManager.release(memoryAcquired);
+      throw new IOException("Could not acquire " + memoryRequested + " bytes of memory");
+    }
+
+    this.sorter = new UnsafeShuffleInMemorySorter(initialSize);
+  }
+
+  /**
+   * Sorts the in-memory records and writes the sorted records to an on-disk file.
+   * This method does not free the sort data structures.
+   *
+   * @param isLastFile if true, this indicates that we're writing the final output file and that the
+   *                   bytes written should be counted towards shuffle write metrics rather than
+   *                   shuffle spill metrics.
+   */
+  private void writeSortedFile(boolean isLastFile) throws IOException {
+
+    final ShuffleWriteMetrics writeMetricsToUse;
+
+    if (isLastFile) {
+      // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
+      writeMetricsToUse = writeMetrics;
+    } else {
+      // We're spilling, so bytes written should be counted towards spill rather than write.
+      // Create a dummy WriteMetrics object to absorb these metrics, since we don't want to count
+      // them towards shuffle bytes written.
+      writeMetricsToUse = new ShuffleWriteMetrics();
+    }
+
+    // This call performs the actual sort.
+    final UnsafeShuffleInMemorySorter.UnsafeShuffleSorterIterator sortedRecords =
+      sorter.getSortedIterator();
+
+    // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
+    // after SPARK-5581 is fixed.
+    BlockObjectWriter writer;
+
+    // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
+    // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
+    // data through a byte array. This array does not need to be large enough to hold a single
+    // record.
+    final byte[] writeBuffer = new byte[DISK_WRITE_BUFFER_SIZE];
+
+    // Because this output will be read during shuffle, its compression codec must be controlled by
+    // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+    // createTempShuffleBlock here; see SPARK-3426 for more details.
+    final Tuple2<TempShuffleBlockId, File> spilledFileInfo =
+      blockManager.diskBlockManager().createTempShuffleBlock();
+    final File file = spilledFileInfo._2();
+    final TempShuffleBlockId blockId = spilledFileInfo._1();
+    final SpillInfo spillInfo = new SpillInfo(numPartitions, file, blockId);
+
+    // Unfortunately, we need a serializer instance in order to construct a DiskBlockObjectWriter.
+    // Our write path doesn't actually use this serializer (since we end up calling the `write()`
+    // OutputStream methods), but DiskBlockObjectWriter still calls some methods on it. To work
+    // around this, we pass a dummy no-op serializer.
+    final SerializerInstance ser = DummySerializerInstance.INSTANCE;
+
+    writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+
+    int currentPartition = -1;
+    while (sortedRecords.hasNext()) {
+      sortedRecords.loadNext();
+      final int partition = sortedRecords.packedRecordPointer.getPartitionId();
+      assert (partition >= currentPartition);
+      if (partition != currentPartition) {
+        // Switch to the new partition
+        if (currentPartition != -1) {
+          writer.commitAndClose();
+          spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        }
+        currentPartition = partition;
+        writer =
+          blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+      }
+
+      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
+      final Object recordPage = memoryManager.getPage(recordPointer);
+      final long recordOffsetInPage = memoryManager.getOffsetInPage(recordPointer);
+      int dataRemaining = PlatformDependent.UNSAFE.getInt(recordPage, recordOffsetInPage);
+      long recordReadPosition = recordOffsetInPage + 4; // skip over record length
+      while (dataRemaining > 0) {
+        final int toTransfer = Math.min(DISK_WRITE_BUFFER_SIZE, dataRemaining);
+        PlatformDependent.copyMemory(
+          recordPage,
+          recordReadPosition,
+          writeBuffer,
+          PlatformDependent.BYTE_ARRAY_OFFSET,
+          toTransfer);
+        writer.write(writeBuffer, 0, toTransfer);
+        recordReadPosition += toTransfer;
+        dataRemaining -= toTransfer;
+      }
+      writer.recordWritten();
+    }
+
+    if (writer != null) {
+      writer.commitAndClose();
+      // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+      // then the file might be empty. Note that it might be better to avoid calling
+      // writeSortedFile() in that case.
+      if (currentPartition != -1) {
+        spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+        spills.add(spillInfo);
+      }
+    }
+
+    if (!isLastFile) {  // i.e. this is a spill file
+      // The current semantics of `shuffleRecordsWritten` seem to be that it's updated when records
+      // are written to disk, not when they enter the shuffle sorting code. DiskBlockObjectWriter
+      // relies on its `recordWritten()` method being called in order to trigger periodic updates to
+      // `shuffleBytesWritten`. If we were to remove the `recordWritten()` call and increment that
+      // counter at a higher-level, then the in-progress metrics for records written and bytes
+      // written would get out of sync.
+      //
+      // When writing the last file, we pass `writeMetrics` directly to the DiskBlockObjectWriter;
+      // in all other cases, we pass in a dummy write metrics to capture metrics, then copy those
+      // metrics to the true write metrics here. The reason for performing this copying is so that
+      // we can avoid reporting spilled bytes as shuffle write bytes.
+      //
+      // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
+      // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
+      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
+      writeMetrics.incShuffleRecordsWritten(writeMetricsToUse.shuffleRecordsWritten());
+      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.shuffleBytesWritten());
+    }
+  }
+
+  /**
+   * Sort and spill the current records in response to memory pressure, then release the sorter's
+   * memory and the data pages and allocate a fresh in-memory sorter for subsequent inserts.
+   */
+  @VisibleForTesting
+  void spill() throws IOException {
+    // The "{} {}" placeholders already supply the separating space, so the unit has no padding.
+    logger.info("Thread {} spilling sort data of {} to disk ({} {} so far)",
+      Thread.currentThread().getId(),
+      Utils.bytesToString(getMemoryUsage()),
+      spills.size(),
+      spills.size() == 1 ? "time" : "times");
+    writeSortedFile(false);
+    final long sorterMemoryUsage = sorter.getMemoryUsage();
+    sorter = null;
+    shuffleMemoryManager.release(sorterMemoryUsage);
+    final long spillSize = freeMemory();
+    taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
+    initializeForWriting();
+  }
+
+  private long getMemoryUsage() {
+    return sorter.getMemoryUsage() + (allocatedPages.size() * (long) PAGE_SIZE);
+  }
+
+  private long freeMemory() {
+    long memoryFreed = 0;
+    for (MemoryBlock block : allocatedPages) {
+      memoryManager.freePage(block);
+      shuffleMemoryManager.release(block.size());
+      memoryFreed += block.size();
+    }
+    allocatedPages.clear();
+    currentPage = null;
+    currentPagePosition = -1;
+    freeSpaceInCurrentPage = 0;
+    return memoryFreed;
+  }
+
+  /**
+   * Force all memory and spill files to be deleted; called by shuffle error-handling code.
+   */
+  public void cleanupAfterError() {
+    freeMemory();
+    for (SpillInfo spill : spills) {
+      if (spill.file.exists() && !spill.file.delete()) {
+        logger.error("Unable to delete spill file {}", spill.file.getPath());
+      }
+    }
+    if (sorter != null) {
+      shuffleMemoryManager.release(sorter.getMemoryUsage());
+      sorter = null;
+    }
+  }
+
+  /**
+   * Checks whether there is enough space to insert a new record into the sorter.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size.
+   *
+   * @return true if the record can be inserted without requiring more allocations, false otherwise.
+   */
+  private boolean haveSpaceForRecord(int requiredSpace) {
+    assert (requiredSpace > 0);
+    return (sorter.hasSpaceForAnotherRecord() && (requiredSpace <= freeSpaceInCurrentPage));
+  }
+
+  /**
+   * Allocates more memory in order to insert an additional record. This will request additional
+   * memory from the {@link ShuffleMemoryManager} and spill if the requested memory can not be
+   * obtained. Throws IOException if the record exceeds the page size or no page can be acquired.
+   *
+   * @param requiredSpace the required space in the data page, in bytes, including space for storing
+   *                      the record size.
+   */
+  private void allocateSpaceForRecord(int requiredSpace) throws IOException {
+    if (!sorter.hasSpaceForAnotherRecord()) {
+      logger.debug("Attempting to expand sort pointer array");
+      final long oldPointerArrayMemoryUsage = sorter.getMemoryUsage();
+      final long memoryToGrowPointerArray = oldPointerArrayMemoryUsage * 2;
+      final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowPointerArray);
+      if (memoryAcquired < memoryToGrowPointerArray) {
+        shuffleMemoryManager.release(memoryAcquired);
+        spill();
+      } else {
+        sorter.expandPointerArray();
+        shuffleMemoryManager.release(oldPointerArrayMemoryUsage);
+      }
+    }
+    if (requiredSpace > freeSpaceInCurrentPage) {
+      logger.trace("Required space {} is greater than free space in current page ({})",
+        requiredSpace, freeSpaceInCurrentPage);
+      // TODO: we should track metrics on the amount of space wasted when we roll over to a new page
+      // without using the free space at the end of the current page. We should also do this for
+      // BytesToBytesMap.
+      if (requiredSpace > PAGE_SIZE) {
+        throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
+          PAGE_SIZE + ")");
+      } else {
+        final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+        if (memoryAcquired < PAGE_SIZE) {
+          shuffleMemoryManager.release(memoryAcquired);
+          spill();
+          final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
+          if (memoryAcquiredAfterSpilling != PAGE_SIZE) {
+            shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+            throw new IOException("Unable to acquire " + PAGE_SIZE + " bytes of memory");
+          }
+        }
+        currentPage = memoryManager.allocatePage(PAGE_SIZE);
+        currentPagePosition = currentPage.getBaseOffset();
+        freeSpaceInCurrentPage = PAGE_SIZE;
+        allocatedPages.add(currentPage);
+      }
+    }
+  }
+
+  /**
+   * Write a record to the shuffle sorter.
+   */
+  public void insertRecord(
+      Object recordBaseObject,
+      long recordBaseOffset,
+      int lengthInBytes,
+      int partitionId) throws IOException {
+    // Need 4 bytes to store the record length.
+    final int totalSpaceRequired = lengthInBytes + 4;
+    if (!haveSpaceForRecord(totalSpaceRequired)) {
+      allocateSpaceForRecord(totalSpaceRequired);
+    }
+
+    final long recordAddress =
+      memoryManager.encodePageNumberAndOffset(currentPage, currentPagePosition);
+    final Object dataPageBaseObject = currentPage.getBaseObject();
+    PlatformDependent.UNSAFE.putInt(dataPageBaseObject, currentPagePosition, lengthInBytes);
+    currentPagePosition += 4;
+    freeSpaceInCurrentPage -= 4;
+    PlatformDependent.copyMemory(
+      recordBaseObject,
+      recordBaseOffset,
+      dataPageBaseObject,
+      currentPagePosition,
+      lengthInBytes);
+    currentPagePosition += lengthInBytes;
+    freeSpaceInCurrentPage -= lengthInBytes;
+    sorter.insertRecord(recordAddress, partitionId);
+  }
+
+  /**
+   * Close the sorter, causing any buffered data to be sorted and written out to disk.
+   *
+   * @return metadata for the spill files written by this sorter. If no records were ever inserted
+   *         into this sorter, then this will return an empty array.
+   * @throws IOException
+   */
+  public SpillInfo[] closeAndGetSpills() throws IOException {
+    try {
+      if (sorter != null) {
+        // Do not count the final file towards the spill count.
+        writeSortedFile(true);
+        freeMemory();
+      }
+      return spills.toArray(new SpillInfo[spills.size()]);
+    } catch (IOException e) {
+      cleanupAfterError();
+      throw e;
+    }
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
new file mode 100644
index 0000000..5bab501
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleInMemorySorter.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.util.Comparator;
+
+import org.apache.spark.util.collection.Sorter;
+
+final class UnsafeShuffleInMemorySorter {
+
+  private final Sorter<PackedRecordPointer, long[]> sorter;
+  private static final class SortComparator implements Comparator<PackedRecordPointer> {
+    @Override
+    public int compare(PackedRecordPointer lhs, PackedRecordPointer rhs) {
+      return lhs.getPartitionId() - rhs.getPartitionId();
+    }
+  }
+  private static final SortComparator SORT_COMPARATOR = new SortComparator();
+
+  /**
+   * Backing storage for the sort. Each slot holds one record pointer plus its partition id,
+   * packed into a single long by {@link PackedRecordPointer}. Sorting rearranges these slots and
+   * never touches the records themselves.
+   */
+  private long[] pointerArray;
+
+  /**
+   * Index of the next free slot in the pointer array.
+   */
+  private int pointerArrayInsertPosition = 0;
+
+  public UnsafeShuffleInMemorySorter(int initialSize) {
+    assert (initialSize > 0);
+    this.pointerArray = new long[initialSize];
+    this.sorter = new Sorter<PackedRecordPointer, long[]>(UnsafeShuffleSortDataFormat.INSTANCE);
+  }
+
+  public void expandPointerArray() {
+    final long[] previous = pointerArray;
+    // Doubling can overflow int for very large arrays; fall back to the largest int in that case.
+    final int grownLength = previous.length * 2 > 0 ? (previous.length * 2) : Integer.MAX_VALUE;
+    pointerArray = new long[grownLength];
+    System.arraycopy(previous, 0, pointerArray, 0, previous.length);
+  }
+
+  public boolean hasSpaceForAnotherRecord() {
+    return pointerArray.length > pointerArrayInsertPosition + 1;
+  }
+
+  public long getMemoryUsage() {
+    return 8L * pointerArray.length;
+  }
+
+  /**
+   * Appends one record to the set of records to be sorted, growing the pointer array if needed.
+   *
+   * @param recordPointer a pointer to the record, encoded by the task memory manager. Due to
+   *                      certain pointer compression techniques used by the sorter, the sort can
+   *                      only operate on pointers that point to locations in the first
+   *                      {@link PackedRecordPointer#MAXIMUM_PAGE_SIZE_BYTES} bytes of a data page.
+   * @param partitionId the partition id, which must be less than or equal to
+   *                    {@link PackedRecordPointer#MAXIMUM_PARTITION_ID}.
+   */
+  public void insertRecord(long recordPointer, int partitionId) {
+    if (!hasSpaceForAnotherRecord()) {
+      // An array already at Integer.MAX_VALUE entries cannot grow, so fail instead of expanding.
+      if (pointerArray.length == Integer.MAX_VALUE) {
+        throw new IllegalStateException("Sort pointer array has reached maximum size");
+      }
+      expandPointerArray();
+    }
+    final long packed = PackedRecordPointer.packPointer(recordPointer, partitionId);
+    pointerArray[pointerArrayInsertPosition] = packed;
+    pointerArrayInsertPosition++;
+  }
+
+  /**
+   * A cursor over packed pointers, used in place of Java's Iterator in order to ease inlining.
+   */
+  public static final class UnsafeShuffleSorterIterator {
+
+    private final long[] pointerArray;
+    private final int numRecords;
+    final PackedRecordPointer packedRecordPointer = new PackedRecordPointer();
+    private int position = 0;
+
+    public UnsafeShuffleSorterIterator(int numRecords, long[] pointerArray) {
+      this.pointerArray = pointerArray;
+      this.numRecords = numRecords;
+    }
+
+    public boolean hasNext() {
+      return numRecords > position;
+    }
+
+    public void loadNext() {
+      packedRecordPointer.set(pointerArray[position]);
+      position += 1;
+    }
+  }
+
+  /**
+   * Sorts the inserted pointers by partition id and returns a cursor over them.
+   */
+  public UnsafeShuffleSorterIterator getSortedIterator() {
+    sorter.sort(pointerArray, 0, pointerArrayInsertPosition, SORT_COMPARATOR);
+    return new UnsafeShuffleSorterIterator(pointerArrayInsertPosition, pointerArray);
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
new file mode 100644
index 0000000..a66d74e
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSortDataFormat.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import org.apache.spark.util.collection.SortDataFormat;
+
+final class UnsafeShuffleSortDataFormat extends SortDataFormat<PackedRecordPointer, long[]> {
+
+  public static final UnsafeShuffleSortDataFormat INSTANCE = new UnsafeShuffleSortDataFormat();
+
+  private UnsafeShuffleSortDataFormat() { }  // not instantiable from outside; use INSTANCE
+
+  @Override
+  public PackedRecordPointer getKey(long[] data, int pos) {
+    // Keys are re-used via the three-argument overload below, so this one is left unsupported.
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public PackedRecordPointer newKey() {
+    return new PackedRecordPointer();  // presumably the scratch key later passed back as `reuse`
+  }
+
+  @Override
+  public PackedRecordPointer getKey(long[] data, int pos, PackedRecordPointer reuse) {
+    reuse.set(data[pos]);  // overwrite the supplied key object instead of allocating a new one
+    return reuse;
+  }
+
+  @Override
+  public void swap(long[] data, int pos0, int pos1) {
+    final long temp = data[pos0];
+    data[pos0] = data[pos1];
+    data[pos1] = temp;
+  }
+
+  @Override
+  public void copyElement(long[] src, int srcPos, long[] dst, int dstPos) {
+    dst[dstPos] = src[srcPos];
+  }
+
+  @Override
+  public void copyRange(long[] src, int srcPos, long[] dst, int dstPos, int length) {
+    System.arraycopy(src, srcPos, dst, dstPos, length);
+  }
+
+  @Override
+  public long[] allocate(int length) {
+    return new long[length];  // same element type as the array being sorted (one long per record)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
new file mode 100644
index 0000000..ad7eb04
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
@@ -0,0 +1,438 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.*;
+import java.nio.channels.FileChannel;
+import java.util.Iterator;
+import javax.annotation.Nullable;
+
+import scala.Option;
+import scala.Product2;
+import scala.collection.JavaConversions;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
+
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.io.ByteStreams;
+import com.google.common.io.Closeables;
+import com.google.common.io.Files;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.*;
+import org.apache.spark.annotation.Private;
+import org.apache.spark.io.CompressionCodec;
+import org.apache.spark.io.CompressionCodec$;
+import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.scheduler.MapStatus;
+import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.serializer.SerializationStream;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.IndexShuffleBlockResolver;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.shuffle.ShuffleWriter;
+import org.apache.spark.storage.BlockManager;
+import org.apache.spark.storage.TimeTrackingOutputStream;
+import org.apache.spark.unsafe.PlatformDependent;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+
+@Private
+public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
+
+  private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleWriter.class);
+
+  private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
+
+  @VisibleForTesting
+  static final int INITIAL_SORT_BUFFER_SIZE = 4096;
+
+  private final BlockManager blockManager;
+  private final IndexShuffleBlockResolver shuffleBlockResolver;
+  private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
+  private final SerializerInstance serializer;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
+  private final TaskContext taskContext;
+  private final SparkConf sparkConf;
+  private final boolean transferToEnabled;
+
+  private MapStatus mapStatus = null;
+  private UnsafeShuffleExternalSorter sorter = null;
+
+  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+    public MyByteArrayOutputStream(int size) { super(size); }
+    public byte[] getBuf() { return buf; }
+  }
+
+  private MyByteArrayOutputStream serBuffer;
+  private SerializationStream serOutputStream;
+
+  /**
+   * Are we in the process of stopping? Because map tasks can call stop() with success = true
+   * and then call stop() with success = false if they get an exception, we want to make sure
+   * we don't try deleting files, etc twice.
+   */
+  private boolean stopping = false;
+
+  public UnsafeShuffleWriter(
+      BlockManager blockManager,
+      IndexShuffleBlockResolver shuffleBlockResolver,
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      UnsafeShuffleHandle<K, V> handle,
+      int mapId,
+      TaskContext taskContext,
+      SparkConf sparkConf) throws IOException {
+    final int numPartitions = handle.dependency().partitioner().numPartitions();
+    if (numPartitions > UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS()) {  // fail fast
+      throw new IllegalArgumentException(
+        "UnsafeShuffleWriter can only be used for shuffles with at most " +
+          UnsafeShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS() + " reduce partitions");
+    }
+    this.blockManager = blockManager;
+    this.shuffleBlockResolver = shuffleBlockResolver;
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
+    this.mapId = mapId;
+    final ShuffleDependency<K, V, V> dep = handle.dependency();
+    this.shuffleId = dep.shuffleId();
+    this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+    this.partitioner = dep.partitioner();
+    this.writeMetrics = new ShuffleWriteMetrics();  // installed on the task's metrics just below
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.taskContext = taskContext;
+    this.sparkConf = sparkConf;
+    this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);
+    open();  // creates the external sorter and the serialization buffer/stream
+  }
+
+  /**
+   * This convenience method should only be called in test code.
+   */
+  @VisibleForTesting
+  public void write(Iterator<Product2<K, V>> records) throws IOException {
+    write(JavaConversions.asScalaIterator(records));  // delegate to the Scala-iterator overload
+  }
+
+  @Override
+  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean success = false;
+    try {
+      while (records.hasNext()) {
+        insertRecordIntoSorter(records.next());
+      }
+      closeAndWriteOutput();
+      success = true;
+    } finally {
+      if (!success && sorter != null) {  // sorter is already null if the merge step failed
+        sorter.cleanupAfterError();
+      }
+    }
+  }
+
+  private void open() throws IOException {
+    assert (sorter == null);  // open() must not run while a sorter is still live
+    sorter = new UnsafeShuffleExternalSorter(
+      memoryManager,
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
+      INITIAL_SORT_BUFFER_SIZE,
+      partitioner.numPartitions(),
+      sparkConf,
+      writeMetrics);
+    serBuffer = new MyByteArrayOutputStream(1024 * 1024);  // 1 MB initial capacity
+    serOutputStream = serializer.serializeStream(serBuffer);  // records are serialized into serBuffer
+  }
+
+  @VisibleForTesting
+  void closeAndWriteOutput() throws IOException {
+    serBuffer = null;  // drop the serialization buffer and stream; no more records are inserted
+    serOutputStream = null;
+    final SpillInfo[] spills = sorter.closeAndGetSpills();
+    sorter = null;  // from here on, stop() will not run sorter cleanup (it checks for non-null)
+    final long[] partitionLengths;
+    try {
+      partitionLengths = mergeSpills(spills);
+    } finally {  // spill files are deleted whether or not the merge succeeded
+      for (SpillInfo spill : spills) {
+        if (spill.file.exists() && ! spill.file.delete()) {
+          logger.error("Error while deleting spill file {}", spill.file.getPath());
+        }
+      }
+    }
+    shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+  }
+
+  @VisibleForTesting
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException {
+    final K key = record._1();
+    final int partitionId = partitioner.getPartition(key);
+    serBuffer.reset();  // the buffer is shared across records, so each record starts at offset 0
+    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+    serOutputStream.flush();  // presumably pushes serializer-buffered bytes into serBuffer
+
+    final int serializedRecordSize = serBuffer.size();
+    assert (serializedRecordSize > 0);
+
+    sorter.insertRecord(  // hands over the raw backing array, not a copy of the valid bytes
+      serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  }
+
+  @VisibleForTesting
+  void forceSorterToSpill() throws IOException {
+    assert (sorter != null);  // only meaningful before closeAndWriteOutput() nulls the sorter
+    sorter.spill();  // test hook: trigger a spill on demand
+  }
+
+  /**
+   * Merge zero or more spill files together, choosing the fastest merging strategy based on the
+   * number of spills and the IO compression codec.
+   *
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpills(SpillInfo[] spills) throws IOException {
+    final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+    final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
+    final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
+    final boolean fastMergeEnabled =
+      sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
+    final boolean fastMergeIsSupported =
+      !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+    try {
+      if (spills.length == 0) {
+        new FileOutputStream(outputFile).close(); // Create an empty file
+        return new long[partitioner.numPartitions()];  // every partition has length zero
+      } else if (spills.length == 1) {
+        // Here, we don't need to perform any metrics updates because the bytes written to this
+        // output file would have already been counted as shuffle bytes written.
+        Files.move(spills[0].file, outputFile);
+        return spills[0].partitionLengths;
+      } else {
+        final long[] partitionLengths;
+        // There are multiple spills to merge, so none of these spill files' lengths were counted
+        // towards our shuffle write count or shuffle write time. If we use the slow merge path,
+        // then the final output file's size won't necessarily be equal to the sum of the spill
+        // files' sizes. To guard against this case, we look at the output file's actual size when
+        // computing shuffle bytes written.
+        //
+        // We allow the individual merge methods to report their own IO times since different merge
+        // strategies use different IO techniques.  We count IO during merge towards the
+        // shuffle write time, which appears to be consistent with the "not bypassing merge-sort"
+        // branch in ExternalSorter.
+        if (fastMergeEnabled && fastMergeIsSupported) {
+          // Compression is disabled or we are using an IO compression codec that supports
+          // decompression of concatenated compressed streams, so we can perform a fast spill merge
+          // that doesn't need to interpret the spilled bytes.
+          if (transferToEnabled) {
+            logger.debug("Using transferTo-based fast merge");
+            partitionLengths = mergeSpillsWithTransferTo(spills, outputFile);
+          } else {
+            logger.debug("Using fileStream-based fast merge");
+            partitionLengths = mergeSpillsWithFileStream(spills, outputFile, null);  // no codec
+          }
+        } else {
+          logger.debug("Using slow merge");
+          partitionLengths = mergeSpillsWithFileStream(spills, outputFile, compressionCodec);
+        }
+        // When closing an UnsafeShuffleExternalSorter that has already spilled once but also has
+        // in-memory records, we write out the in-memory records to a file but do not count that
+        // final write as bytes spilled (instead, it's accounted as shuffle write). The merge needs
+        // to be counted as shuffle write, but this will lead to double-counting of the final
+        // SpillInfo's bytes.
+        writeMetrics.decShuffleBytesWritten(spills[spills.length - 1].file.length());
+        writeMetrics.incShuffleBytesWritten(outputFile.length());
+        return partitionLengths;
+      }
+    } catch (IOException e) {  // do not leave a partially written output file behind
+      if (outputFile.exists() && !outputFile.delete()) {
+        logger.error("Unable to delete output file {}", outputFile.getPath());
+      }
+      throw e;
+    }
+  }
+
+  /**
+   * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge,
+   * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in
+   * cases where the IO compression codec does not support concatenation of compressed data, or in
+   * cases where users have explicitly disabled use of {@code transferTo} in order to work around
+   * kernel bugs.
+   *
+   * @param spills the spills to merge.
+   * @param outputFile the file to write the merged data to.
+   * @param compressionCodec the IO compression codec, or null if shuffle compression is disabled.
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpillsWithFileStream(
+      SpillInfo[] spills,
+      File outputFile,
+      @Nullable CompressionCodec compressionCodec) throws IOException {
+    assert (spills.length >= 2);
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
+    final InputStream[] spillInputStreams = new FileInputStream[spills.length];
+    OutputStream mergedFileOutputStream = null;
+
+    boolean threwException = true;  // cleared only after every partition has been merged
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputStreams[i] = new FileInputStream(spills[i].file);
+      }
+      for (int partition = 0; partition < numPartitions; partition++) {
+        final long initialFileLength = outputFile.length();  // baseline for this partition's size
+        mergedFileOutputStream =
+          new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true));
+        if (compressionCodec != null) {
+          mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream);
+        }
+
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          if (partitionLengthInSpill > 0) {
+            InputStream partitionInputStream =  // view limited to this partition's byte count
+              new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
+            if (compressionCodec != null) {
+              partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+            }
+            ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+          }
+        }
+        mergedFileOutputStream.flush();
+        mergedFileOutputStream.close();  // closed per partition, before the file length is read
+        partitionLengths[partition] = (outputFile.length() - initialFileLength);
+      }
+      threwException = false;
+    } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw exceptions during cleanup if threwException == false.
+      for (InputStream stream : spillInputStreams) {
+        Closeables.close(stream, threwException);
+      }
+      Closeables.close(mergedFileOutputStream, threwException);
+    }
+    return partitionLengths;
+  }
+
+  /**
+   * Merges spill files by using NIO's transferTo to concatenate spill partitions' bytes.
+   * This is only safe when the IO compression codec and serializer support concatenation of
+   * serialized streams.
+   *
+   * @return the partition lengths in the merged file.
+   */
+  private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) throws IOException {
+    assert (spills.length >= 2);
+    final int numPartitions = partitioner.numPartitions();
+    final long[] partitionLengths = new long[numPartitions];
+    final FileChannel[] spillInputChannels = new FileChannel[spills.length];
+    final long[] spillInputChannelPositions = new long[spills.length];
+    FileChannel mergedFileOutputChannel = null;
+
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < spills.length; i++) {
+        spillInputChannels[i] = new FileInputStream(spills[i].file).getChannel();
+      }
+      // This file needs to be opened in append mode in order to work around a Linux kernel bug
+      // that affects transferTo; see SPARK-3948 for more details.
+      mergedFileOutputChannel = new FileOutputStream(outputFile, true).getChannel();
+
+      long bytesWrittenToMergedFile = 0;
+      for (int partition = 0; partition < numPartitions; partition++) {
+        for (int i = 0; i < spills.length; i++) {
+          final long partitionLengthInSpill = spills[i].partitionLengths[partition];
+          long bytesToTransfer = partitionLengthInSpill;
+          final FileChannel spillInputChannel = spillInputChannels[i];
+          final long writeStartTime = System.nanoTime();
+          while (bytesToTransfer > 0) {  // transferTo may move fewer bytes than requested
+            final long actualBytesTransferred = spillInputChannel.transferTo(
+              spillInputChannelPositions[i],
+              bytesToTransfer,
+              mergedFileOutputChannel);
+            spillInputChannelPositions[i] += actualBytesTransferred;
+            bytesToTransfer -= actualBytesTransferred;
+          }
+          writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+          bytesWrittenToMergedFile += partitionLengthInSpill;
+          partitionLengths[partition] += partitionLengthInSpill;
+        }
+      }
+      // Check the position after transferTo loop to see if it is in the right position and raise an
+      // exception if it is incorrect. The position will not be increased to the expected length
+      // after calling transferTo in kernel version 2.6.32. This issue is described at
+      // https://bugs.openjdk.java.net/browse/JDK-7052359 and SPARK-3948.
+      if (mergedFileOutputChannel.position() != bytesWrittenToMergedFile) {
+        throw new IOException(
+          "Current position " + mergedFileOutputChannel.position() + " does not equal expected " +
+            "position " + bytesWrittenToMergedFile + " after transferTo. Please check your kernel" +
+            " version to see if it is 2.6.32, as there is a kernel bug which will lead to " +
+            "unexpected behavior when using transferTo. You can set spark.file.transferTo=false " +
+            "to disable this NIO feature."
+        );
+      }
+      threwException = false;
+    } finally {
+      // To avoid masking exceptions that caused us to prematurely enter the finally block, only
+      // throw (or trip the position assertion) during cleanup if threwException == false.
+      for (int i = 0; i < spills.length; i++) {
+        assert(threwException || spillInputChannelPositions[i] == spills[i].file.length());
+        Closeables.close(spillInputChannels[i], threwException);
+      }
+      Closeables.close(mergedFileOutputChannel, threwException);
+    }
+    return partitionLengths;
+  }
+
+  @Override
+  public Option<MapStatus> stop(boolean success) {
+    try {
+      if (stopping) {  // repeated call: the first stop() already produced a result or cleaned up
+        return Option.apply(null);  // Option.apply(null) yields None
+      } else {
+        stopping = true;
+        if (success) {
+          if (mapStatus == null) {  // mapStatus is only assigned in closeAndWriteOutput()
+            throw new IllegalStateException("Cannot call stop(true) without having called write()");
+          }
+          return Option.apply(mapStatus);
+        } else {
+          // The map task failed, so delete our output data.
+          shuffleBlockResolver.removeDataByMap(shuffleId, mapId);
+          return Option.apply(null);
+        }
+      }
+    } finally {
+      if (sorter != null) {
+        // If sorter is non-null, then this implies that we called stop() in response to an error,
+        // so we need to clean up memory and spill files created by the sorter
+        sorter.cleanupAfterError();
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
new file mode 100644
index 0000000..dc2aa30
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage;
+
+import java.io.IOException;
+import java.io.OutputStream;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+
+/**
+ * Intercepts write calls and tracks total time spent writing in order to update shuffle write
+ * metrics. Not thread safe.
+ */
+@Private
+public final class TimeTrackingOutputStream extends OutputStream {
+
+  private final ShuffleWriteMetrics writeMetrics;
+  private final OutputStream outputStream;
+
+  public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+    this.outputStream = outputStream;
+    this.writeMetrics = writeMetrics;
+  }
+
+  @Override
+  public void write(int b) throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.write(b);
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - begin);  // skipped if the write throws
+  }
+
+  @Override
+  public void write(byte[] b) throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.write(b);
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - begin);
+  }
+
+  @Override
+  public void write(byte[] b, int off, int len) throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.write(b, off, len);
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - begin);
+  }
+
+  @Override
+  public void flush() throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.flush();
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - begin);  // flush time counts as write time
+  }
+
+  @Override
+  public void close() throws IOException {
+    final long begin = System.nanoTime();
+    outputStream.close();
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - begin);  // close time counts as well
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/SparkEnv.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0c4d28f..a5d831c 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -313,7 +313,8 @@ object SparkEnv extends Logging {
     // Let the user specify short names for shuffle managers
     val shortShuffleMgrNames = Map(
       "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager",
-      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager")
+      "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager",
+      "tungsten-sort" -> "org.apache.spark.shuffle.unsafe.UnsafeShuffleManager")
     val shuffleMgrName = conf.get("spark.shuffle.manager", "sort")
     val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName)
     val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass)

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index dfbde7c..698d138 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -121,6 +121,8 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
   private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
   private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)
 
+  protected def this() = this(new SparkConf())  // For deserialization only
+
   override def newInstance(): SerializerInstance = {
     val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
     new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index 6ad427b..6c3b308 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -76,7 +76,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
   private val consolidateShuffleFiles =
     conf.getBoolean("spark.shuffle.consolidateFiles", false)
 
-  // Use getSizeAsKb (not bytes) to maintain backwards compatibility of on units are provided 
+  // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index f6e6fe5..4cc4ef5 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -17,14 +17,17 @@
 
 package org.apache.spark.shuffle
 
+import java.io.IOException
+
 import org.apache.spark.scheduler.MapStatus
 
 /**
  * Obtained inside a map task to write out records to the shuffle system.
  */
-private[spark] trait ShuffleWriter[K, V] {
+private[spark] abstract class ShuffleWriter[K, V] {
   /** Write a sequence of records to this task's output */
-  def write(records: Iterator[_ <: Product2[K, V]]): Unit
+  @throws[IOException]
+  def write(records: Iterator[Product2[K, V]]): Unit
 
   /** Close this writer, passing along whether the map completed */
   def stop(success: Boolean): Option[MapStatus]

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 897f0a5..eb87cee 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -49,7 +49,7 @@ private[spark] class HashShuffleWriter[K, V](
     writeMetrics)
 
   /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  override def write(records: Iterator[Product2[K, V]]): Unit = {
     val iter = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         dep.aggregator.get.combineValuesByKey(records, context)

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 1584294..d7fab35 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -72,7 +72,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
     true
   }
 
-  override def shuffleBlockResolver: IndexShuffleBlockResolver = {
+  override val shuffleBlockResolver: IndexShuffleBlockResolver = {
     indexShuffleBlockResolver
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/73bed408/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index add2656..c9dd6bf 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -48,7 +48,7 @@ private[spark] class SortShuffleWriter[K, V, C](
   context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
 
   /** Write a bunch of records to this task's output */
-  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  override def write(records: Iterator[Product2[K, V]]): Unit = {
     if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       sorter = new ExternalSorter[K, V, C](


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org