Posted to commits@spark.apache.org by rx...@apache.org on 2015/05/31 00:27:55 UTC

spark git commit: [SPARK-7855] Move bypassMergeSort-handling from ExternalSorter to own component

Repository: spark
Updated Branches:
  refs/heads/master 7716a5a1e -> a6430028e


[SPARK-7855] Move bypassMergeSort-handling from ExternalSorter to own component

Spark's `ExternalSorter` writes shuffle output files during sort-based shuffle. Sort-based shuffle has a configuration, `spark.shuffle.sort.bypassMergeThreshold`: when the number of reduce partitions is at or below this threshold (and no aggregation or ordering is required), ExternalSorter skips sorting and merging and simply writes a separate file per partition; these files are then concatenated together to form the final map output file.
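
For reference, here is a minimal sketch of that bypass decision; it mirrors the `SortShuffleWriter.shouldBypassMergeSort` helper added in this patch, and the example values at the end are only illustrative.

```scala
import org.apache.spark.{Aggregator, SparkConf}

// Sketch of the bypass decision described above (mirrors the
// SortShuffleWriter.shouldBypassMergeSort helper introduced by this patch).
def shouldBypassMergeSort(
    conf: SparkConf,
    numPartitions: Int,
    aggregator: Option[Aggregator[_, _, _]],
    keyOrdering: Option[Ordering[_]]): Boolean = {
  // The threshold defaults to 200 partitions.
  val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
  numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
}

// Example: a 100-partition shuffle with no aggregation and no ordering takes the bypass path.
val conf = new SparkConf()
shouldBypassMergeSort(conf, numPartitions = 100, aggregator = None, keyOrdering = None) // true
```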

The code paths used during this bypass are almost completely separate from ExternalSorter's other code paths, so refactoring them into a separate file can significantly simplify the code.

In addition to rearranging code, this patch deletes a fair amount of dead code. The main entry point into ExternalSorter is `insertAll()`, and in SPARK-4479 / #3422 this method was modified to completely bypass in-memory buffering of records when `bypassMergeSort` takes effect. As a result, some of the spilling and merging code paths are never called when `bypassMergeSort` is used, so we can safely remove that code.

There's an open JIRA ([SPARK-6026](https://issues.apache.org/jira/browse/SPARK-6026)) for removing the `bypassMergeThreshold` parameter and code paths; I have not done that here, but the changes in this patch will make removing that parameter significantly easier if we ever decide to do that.

This patch also makes several improvements to shuffle-related tests and adds more defensive checks to certain shuffle classes:

- DiskBlockObjectWriter now throws an exception if `fileSegment()` is called before `commitAndClose()` has been called (a minimal sketch of this guard pattern appears after this list).
- DiskBlockObjectWriter's close methods are now idempotent: calling any of them twice in a row no longer corrupts the shuffle write metrics, and calling `revertPartialWritesAndClose()` on an already-closed DiskBlockObjectWriter is now a no-op (before, it could skew the metrics).
- The end-to-end shuffle record count metrics tests have been moved from InputOutputMetricsSuite to ShuffleSuite.  This means that these tests will now be run against all shuffle implementations rather than just the default shuffle configuration.
- The end-to-end metrics tests now include a test of a job which performs aggregation in the shuffle.
- Our tests now check that `shuffleBytesWritten == totalShuffleBytesRead`.
- FileSegment now throws IllegalArgumentException if it is constructed with a negative length or offset.
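
The following is a simplified sketch of the defensive-check pattern referred to above; it is not the real DiskBlockObjectWriter or FileSegment code (see the diff below for that), and the class and field names here are illustrative only.

```scala
import java.io.File

// Illustrative stand-in for FileSegment: reject negative offsets/lengths at construction time.
class FileSegmentSketch(val file: File, val offset: Long, val length: Long) {
  require(offset >= 0, s"File segment offset cannot be negative (got $offset)")
  require(length >= 0, s"File segment length cannot be negative (got $length)")
}

// Illustrative stand-in for the disk writer: fileSegment() is only valid after commitAndClose().
class DiskWriterSketch(file: File, initialPosition: Long) {
  private var finalPosition: Long = -1L
  private var commitAndCloseHasBeenCalled = false

  def commitAndClose(): Unit = {
    // ... flush and close the underlying streams ...
    finalPosition = file.length()
    commitAndCloseHasBeenCalled = true
  }

  def fileSegment(): FileSegmentSketch = {
    if (!commitAndCloseHasBeenCalled) {
      throw new IllegalStateException(
        "fileSegment() is only valid after commitAndClose() has been called")
    }
    new FileSegmentSketch(file, initialPosition, finalPosition - initialPosition)
  }
}
```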

Author: Josh Rosen <jo...@databricks.com>

Closes #6397 from JoshRosen/external-sorter-bypass-cleanup and squashes the following commits:

bf3f3f6 [Josh Rosen] Merge remote-tracking branch 'origin/master' into external-sorter-bypass-cleanup
8b216c4 [Josh Rosen] Guard against negative offsets and lengths in FileSegment
03f35a4 [Josh Rosen] Minor fix to cleanup logic.
b5cc35b [Josh Rosen] Move shuffle metrics tests to ShuffleSuite.
8b8fb9e [Josh Rosen] Add more tests + defensive programming to DiskBlockObjectWriter.
16564eb [Josh Rosen] Guard against calling fileSegment() before commitAndClose() has been called.
96811b4 [Josh Rosen] Remove confusing taskMetrics.shuffleWriteMetrics() optional call
8522b6a [Josh Rosen] Do not perform a map-side sort unless we're also doing map-side aggregation
08e40f3 [Josh Rosen] Remove excessively clever (and wrong) implementation of newBuffer()
d7f9938 [Josh Rosen] Add missing overrides; fix compilation
71d76ff [Josh Rosen] Update Javadoc
bf0d98f [Josh Rosen] Add comment to clarify confusing factory code
5197f73 [Josh Rosen] Add missing private[this]
30ef2c8 [Josh Rosen] Convert BypassMergeSortShuffleWriter to Java
bc1a820 [Josh Rosen] Fix bug when aggregator is used but map-side combine is disabled
0d3dcc0 [Josh Rosen] Remove unnecessary overloaded methods
25b964f [Josh Rosen] Rename SortShuffleSorter to SortShuffleFileWriter
0d9848c [Josh Rosen] Make it more clear that curWriteMetrics is now only used for spill metrics
7af7aea [Josh Rosen] Combine spill() and spillToMergeableFile()
6320112 [Josh Rosen] Add missing negation in deletion success check.
d267e0d [Josh Rosen] Fix style issue
7f15f7b [Josh Rosen] Back out extra cleanup-handling code, since this is already covered in stop()
25aa3bd [Josh Rosen] Make sure to delete outputFile after errors.
931ca68 [Josh Rosen] Refactor tests.
6a35716 [Josh Rosen] Refactor logic for deciding when to bypass
4b03539 [Josh Rosen] Move conf prior to first use
1265b25 [Josh Rosen] Fix some style errors and comments.
02355ef [Josh Rosen] More simplification
d4cb536 [Josh Rosen] Delete more unused code
bb96678 [Josh Rosen] Add missing interface file
b6cc1eb [Josh Rosen] Realize that bypass never buffers; proceed to delete tons of code
6185ee2 [Josh Rosen] WIP towards moving bypass code into own file.
8d0678c [Josh Rosen] Move diskBytesSpilled getter next to variable
19bccd6 [Josh Rosen] Remove duplicated buffer creation code.
18959bb [Josh Rosen] Move comparator methods closer together.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6430028
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6430028
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6430028

Branch: refs/heads/master
Commit: a6430028ecd7a6130f1eb15af9ec00e242c46725
Parents: 7716a5a
Author: Josh Rosen <jo...@databricks.com>
Authored: Sat May 30 15:27:51 2015 -0700
Committer: Reynold Xin <rx...@databricks.com>
Committed: Sat May 30 15:27:51 2015 -0700

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      | 184 +++++++++++++
 .../shuffle/sort/SortShuffleFileWriter.java     |  53 ++++
 .../spark/shuffle/sort/SortShuffleWriter.scala  |  34 ++-
 .../spark/storage/BlockObjectWriter.scala       |  19 +-
 .../org/apache/spark/storage/FileSegment.scala  |   2 +
 .../spark/util/collection/ExternalSorter.scala  | 260 +++++--------------
 .../spark/util/collection/PairIterator.scala    |  24 --
 .../collection/PartitionedAppendOnlyMap.scala   |   4 -
 .../util/collection/PartitionedPairBuffer.scala |   4 -
 .../PartitionedSerializedPairBuffer.scala       |   4 -
 .../WritablePartitionedPairCollection.scala     |  36 +--
 .../scala/org/apache/spark/ShuffleSuite.scala   |  65 +++++
 .../spark/metrics/InputOutputMetricsSuite.scala |  28 --
 .../BypassMergeSortShuffleWriterSuite.scala     | 171 ++++++++++++
 .../shuffle/sort/SortShuffleWriterSuite.scala   |  46 ++++
 .../spark/storage/BlockObjectWriterSuite.scala  |  97 ++++++-
 .../util/collection/ExternalSorterSuite.scala   | 130 +---------
 17 files changed, 738 insertions(+), 423 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
new file mode 100644
index 0000000..d3d6280
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.Tuple2;
+import scala.collection.Iterator;
+
+import com.google.common.io.Closeables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.Partitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.storage.*;
+import org.apache.spark.util.Utils;
+
+/**
+ * This class implements sort-based shuffle's hash-style shuffle fallback path. This write path
+ * writes incoming records to separate files, one file per reduce partition, then concatenates these
+ * per-partition files to form a single output file, regions of which are served to reducers.
+ * Records are not buffered in memory. This is essentially identical to
+ * {@link org.apache.spark.shuffle.hash.HashShuffleWriter}, except that it writes output in a format
+ * that can be served / consumed via {@link org.apache.spark.shuffle.IndexShuffleBlockResolver}.
+ * <p>
+ * This write path is inefficient for shuffles with large numbers of reduce partitions because it
+ * simultaneously opens separate serializers and file streams for all partitions. As a result,
+ * {@link SortShuffleManager} only selects this write path when
+ * <ul>
+ *    <li>no Ordering is specified,</li>
+ *    <li>no Aggregator is specified, and</li>
+ *    <li>the number of partitions is less than
+ *      <code>spark.shuffle.sort.bypassMergeThreshold</code>.</li>
+ * </ul>
+ *
+ * This code used to be part of {@link org.apache.spark.util.collection.ExternalSorter} but was
+ * refactored into its own class in order to reduce code complexity; see SPARK-7855 for details.
+ * <p>
+ * There have been proposals to completely remove this code path; see SPARK-6026 for details.
+ */
+final class BypassMergeSortShuffleWriter<K, V> implements SortShuffleFileWriter<K, V> {
+
+  private final Logger logger = LoggerFactory.getLogger(BypassMergeSortShuffleWriter.class);
+
+  private final int fileBufferSize;
+  private final boolean transferToEnabled;
+  private final int numPartitions;
+  private final BlockManager blockManager;
+  private final Partitioner partitioner;
+  private final ShuffleWriteMetrics writeMetrics;
+  private final Serializer serializer;
+
+  /** Array of file writers, one for each partition */
+  private BlockObjectWriter[] partitionWriters;
+
+  public BypassMergeSortShuffleWriter(
+      SparkConf conf,
+      BlockManager blockManager,
+      Partitioner partitioner,
+      ShuffleWriteMetrics writeMetrics,
+      Serializer serializer) {
+    // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
+    this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
+    this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
+    this.numPartitions = partitioner.numPartitions();
+    this.blockManager = blockManager;
+    this.partitioner = partitioner;
+    this.writeMetrics = writeMetrics;
+    this.serializer = serializer;
+  }
+
+  @Override
+  public void insertAll(Iterator<Product2<K, V>> records) throws IOException {
+    assert (partitionWriters == null);
+    if (!records.hasNext()) {
+      return;
+    }
+    final SerializerInstance serInstance = serializer.newInstance();
+    final long openStartTime = System.nanoTime();
+    partitionWriters = new BlockObjectWriter[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
+        blockManager.diskBlockManager().createTempShuffleBlock();
+      final File file = tempShuffleBlockIdPlusFile._2();
+      final BlockId blockId = tempShuffleBlockIdPlusFile._1();
+      partitionWriters[i] =
+        blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics).open();
+    }
+    // Creating the file to write to and creating a disk writer both involve interacting with
+    // the disk, and can take a long time in aggregate when we open many files, so should be
+    // included in the shuffle write time.
+    writeMetrics.incShuffleWriteTime(System.nanoTime() - openStartTime);
+
+    while (records.hasNext()) {
+      final Product2<K, V> record = records.next();
+      final K key = record._1();
+      partitionWriters[partitioner.getPartition(key)].write(key, record._2());
+    }
+
+    for (BlockObjectWriter writer : partitionWriters) {
+      writer.commitAndClose();
+    }
+  }
+
+  @Override
+  public long[] writePartitionedFile(
+      BlockId blockId,
+      TaskContext context,
+      File outputFile) throws IOException {
+    // Track location of the partition starts in the output file
+    final long[] lengths = new long[numPartitions];
+    if (partitionWriters == null) {
+      // We were passed an empty iterator
+      return lengths;
+    }
+
+    final FileOutputStream out = new FileOutputStream(outputFile, true);
+    final long writeStartTime = System.nanoTime();
+    boolean threwException = true;
+    try {
+      for (int i = 0; i < numPartitions; i++) {
+        final FileInputStream in = new FileInputStream(partitionWriters[i].fileSegment().file());
+        boolean copyThrewException = true;
+        try {
+          lengths[i] = Utils.copyStream(in, out, false, transferToEnabled);
+          copyThrewException = false;
+        } finally {
+          Closeables.close(in, copyThrewException);
+        }
+        if (!blockManager.diskBlockManager().getFile(partitionWriters[i].blockId()).delete()) {
+          logger.error("Unable to delete file for partition {}", i);
+        }
+      }
+      threwException = false;
+    } finally {
+      Closeables.close(out, threwException);
+      writeMetrics.incShuffleWriteTime(System.nanoTime() - writeStartTime);
+    }
+    partitionWriters = null;
+    return lengths;
+  }
+
+  @Override
+  public void stop() throws IOException {
+    if (partitionWriters != null) {
+      try {
+        final DiskBlockManager diskBlockManager = blockManager.diskBlockManager();
+        for (BlockObjectWriter writer : partitionWriters) {
+          // This method explicitly does _not_ throw exceptions:
+          writer.revertPartialWritesAndClose();
+          if (!diskBlockManager.getFile(writer.blockId()).delete()) {
+            logger.error("Error while deleting file for block {}", writer.blockId());
+          }
+        }
+      } finally {
+        partitionWriters = null;
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
new file mode 100644
index 0000000..656ea04
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/SortShuffleFileWriter.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort;
+
+import java.io.File;
+import java.io.IOException;
+
+import scala.Product2;
+import scala.collection.Iterator;
+
+import org.apache.spark.annotation.Private;
+import org.apache.spark.TaskContext;
+import org.apache.spark.storage.BlockId;
+
+/**
+ * Interface for objects that {@link SortShuffleWriter} uses to write its output files.
+ */
+@Private
+public interface SortShuffleFileWriter<K, V> {
+
+  void insertAll(Iterator<Product2<K, V>> records) throws IOException;
+
+  /**
+   * Write all the data added into this shuffle sorter into a file in the disk store. This is
+   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
+   * binary files if we decided to avoid merge-sorting.
+   *
+   * @param blockId block ID to write to. The index file will be blockId.name + ".index".
+   * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
+   * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
+   */
+  long[] writePartitionedFile(
+      BlockId blockId,
+      TaskContext context,
+      File outputFile) throws IOException;
+
+  void stop() throws IOException;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index c9dd6bf..5865e76 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.shuffle.sort
 
-import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
@@ -35,7 +36,7 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private val blockManager = SparkEnv.get.blockManager
 
-  private var sorter: ExternalSorter[K, V, _] = null
+  private var sorter: SortShuffleFileWriter[K, V] = null
 
   // Are we in the process of stopping? Because map tasks can call stop() with success = true
   // and then call stop() with success = false if they get an exception, we want to make sure
@@ -49,18 +50,27 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
-    if (dep.mapSideCombine) {
+    sorter = if (dep.mapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
-      sorter = new ExternalSorter[K, V, C](
+      new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
-      sorter.insertAll(records)
+    } else if (SortShuffleWriter.shouldBypassMergeSort(
+        SparkEnv.get.conf, dep.partitioner.numPartitions, aggregator = None, keyOrdering = None)) {
+      // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't
+      // need local aggregation and sorting, write numPartitions files directly and just concatenate
+      // them at the end. This avoids doing serialization and deserialization twice to merge
+      // together the spilled files, which would happen with the normal code path. The downside is
+      // having multiple files open at a time and thus more memory allocated to buffers.
+      new BypassMergeSortShuffleWriter[K, V](SparkEnv.get.conf, blockManager, dep.partitioner,
+        writeMetrics, Serializer.getSerializer(dep.serializer))
     } else {
       // In this case we pass neither an aggregator nor an ordering to the sorter, because we don't
       // care whether the keys get sorted in each partition; that will be done on the reduce side
       // if the operation being run is sortByKey.
-      sorter = new ExternalSorter[K, V, V](None, Some(dep.partitioner), None, dep.serializer)
-      sorter.insertAll(records)
+      new ExternalSorter[K, V, V](
+        aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
     }
+    sorter.insertAll(records)
 
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
@@ -100,3 +110,13 @@ private[spark] class SortShuffleWriter[K, V, C](
   }
 }
 
+private[spark] object SortShuffleWriter {
+  def shouldBypassMergeSort(
+      conf: SparkConf,
+      numPartitions: Int,
+      aggregator: Option[Aggregator[_, _, _]],
+      keyOrdering: Option[Ordering[_]]): Boolean = {
+    val bypassMergeThreshold: Int = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    numPartitions <= bypassMergeThreshold && aggregator.isEmpty && keyOrdering.isEmpty
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index a33f22e..7eeabd1 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -95,6 +95,7 @@ private[spark] class DiskBlockObjectWriter(
   private var objOut: SerializationStream = null
   private var initialized = false
   private var hasBeenClosed = false
+  private var commitAndCloseHasBeenCalled = false
 
   /**
    * Cursors used to represent positions in the file.
@@ -167,20 +168,22 @@ private[spark] class DiskBlockObjectWriter(
       objOut.flush()
       bs.flush()
       close()
+      finalPosition = file.length()
+      // In certain compression codecs, more bytes are written after close() is called
+      writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+    } else {
+      finalPosition = file.length()
     }
-    finalPosition = file.length()
-    // In certain compression codecs, more bytes are written after close() is called
-    writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
+    commitAndCloseHasBeenCalled = true
   }
 
   // Discard current writes. We do this by flushing the outstanding writes and then
   // truncating the file to its initial position.
   override def revertPartialWritesAndClose() {
     try {
-      writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
-      writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
-
       if (initialized) {
+        writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
+        writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
         objOut.flush()
         bs.flush()
         close()
@@ -228,6 +231,10 @@ private[spark] class DiskBlockObjectWriter(
   }
 
   override def fileSegment(): FileSegment = {
+    if (!commitAndCloseHasBeenCalled) {
+      throw new IllegalStateException(
+        "fileSegment() is only valid after commitAndClose() has been called")
+    }
     new FileSegment(file, initialPosition, finalPosition - initialPosition)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
index 95e2d68..021a9fa 100644
--- a/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
+++ b/core/src/main/scala/org/apache/spark/storage/FileSegment.scala
@@ -24,6 +24,8 @@ import java.io.File
  * based off an offset and a length.
  */
 private[spark] class FileSegment(val file: File, val offset: Long, val length: Long) {
+  require(offset >= 0, s"File segment offset cannot be negative (got $offset)")
+  require(length >= 0, s"File segment length cannot be negative (got $length)")
   override def toString: String = {
     "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3b9d14f..ef2dbb7 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -23,12 +23,14 @@ import java.util.Comparator
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable
 
+import com.google.common.annotations.VisibleForTesting
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
 import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.storage.{BlockObjectWriter, BlockId}
+import org.apache.spark.shuffle.sort.{SortShuffleFileWriter, SortShuffleWriter}
+import org.apache.spark.storage.{BlockId, BlockObjectWriter}
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -84,35 +86,40 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
  *   each other for equality to merge values.
  *
  * - Users are expected to call stop() at the end to delete all the intermediate files.
- *
- * As a special case, if no Ordering and no Aggregator is given, and the number of partitions is
- * less than spark.shuffle.sort.bypassMergeThreshold, we bypass the merge-sort and just write to
- * separate files for each partition each time we spill, similar to the HashShuffleWriter. We can
- * then concatenate these files to produce a single sorted file, without having to serialize and
- * de-serialize each item twice (as is needed during the merge). This speeds up the map side of
- * groupBy, sort, etc operations since they do no partial aggregation.
  */
 private[spark] class ExternalSorter[K, V, C](
     aggregator: Option[Aggregator[K, V, C]] = None,
     partitioner: Option[Partitioner] = None,
     ordering: Option[Ordering[K]] = None,
     serializer: Option[Serializer] = None)
-  extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
+  extends Logging
+  with Spillable[WritablePartitionedPairCollection[K, C]]
+  with SortShuffleFileWriter[K, V] {
+
+  private val conf = SparkEnv.get.conf
 
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
   private val shouldPartition = numPartitions > 1
+  private def getPartition(key: K): Int = {
+    if (shouldPartition) partitioner.get.getPartition(key) else 0
+  }
+
+  // Since SPARK-7855, bypassMergeSort optimization is no longer performed as part of this class.
+  // As a sanity check, make sure that we're not handling a shuffle which should use that path.
+  if (SortShuffleWriter.shouldBypassMergeSort(conf, numPartitions, aggregator, ordering)) {
+    throw new IllegalArgumentException("ExternalSorter should not be used to handle "
+      + " a sort that the BypassMergeSortShuffleWriter should handle")
+  }
 
   private val blockManager = SparkEnv.get.blockManager
   private val diskBlockManager = blockManager.diskBlockManager
   private val ser = Serializer.getSerializer(serializer)
   private val serInstance = ser.newInstance()
 
-  private val conf = SparkEnv.get.conf
   private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
   
   // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
   private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
-  private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
 
   // Size of object batches when reading/writing from serializers.
   //
@@ -123,43 +130,28 @@ private[spark] class ExternalSorter[K, V, C](
   // grow internal data structures by growing + copying every time the number of objects doubles.
   private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
 
-  private def getPartition(key: K): Int = {
-    if (shouldPartition) partitioner.get.getPartition(key) else 0
-  }
-
-  private val metaInitialRecords = 256
-  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
   private val useSerializedPairBuffer =
-    !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
-    ser.supportsRelocationOfSerializedObjects
-
+    ordering.isEmpty &&
+      conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+      ser.supportsRelocationOfSerializedObjects
+  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+  private def newBuffer(): WritablePartitionedPairCollection[K, C] with SizeTracker = {
+    if (useSerializedPairBuffer) {
+      new PartitionedSerializedPairBuffer(metaInitialRecords = 256, kvChunkSize, serInstance)
+    } else {
+      new PartitionedPairBuffer[K, C]
+    }
+  }
   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
   // store them in an array buffer.
   private var map = new PartitionedAppendOnlyMap[K, C]
-  private var buffer = if (useSerializedPairBuffer) {
-    new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
-  } else {
-    new PartitionedPairBuffer[K, C]
-  }
+  private var buffer = newBuffer()
 
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
+  def diskBytesSpilled: Long = _diskBytesSpilled
 
-  // Write metrics for current spill
-  private var curWriteMetrics: ShuffleWriteMetrics = _
-
-  // If there are fewer than spark.shuffle.sort.bypassMergeThreshold partitions and we don't need
-  // local aggregation and sorting, write numPartitions files directly and just concatenate them
-  // at the end. This avoids doing serialization and deserialization twice to merge together the
-  // spilled files, which would happen with the normal code path. The downside is having multiple
-  // files open at a time and thus more memory allocated to buffers.
-  private val bypassMergeThreshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
-  private val bypassMergeSort =
-    (numPartitions <= bypassMergeThreshold && aggregator.isEmpty && ordering.isEmpty)
-
-  // Array of file writers for each partition, used if bypassMergeSort is true and we've spilled
-  private var partitionWriters: Array[BlockObjectWriter] = null
 
   // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
   // Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -174,6 +166,14 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
+  private def comparator: Option[Comparator[K]] = {
+    if (ordering.isDefined || aggregator.isDefined) {
+      Some(keyComparator)
+    } else {
+      None
+    }
+  }
+
   // Information about a spilled file. Includes sizes in bytes of "batches" written by the
   // serializer as we periodically reset its stream, as well as number of elements in each
   // partition, used to efficiently keep track of partitions when merging.
@@ -182,9 +182,10 @@ private[spark] class ExternalSorter[K, V, C](
     blockId: BlockId,
     serializerBatchSizes: Array[Long],
     elementsPerPartition: Array[Long])
+
   private val spills = new ArrayBuffer[SpilledFile]
 
-  def insertAll(records: Iterator[_ <: Product2[K, V]]): Unit = {
+  override def insertAll(records: Iterator[Product2[K, V]]): Unit = {
     // TODO: stop combining if we find that the reduction factor isn't high
     val shouldCombine = aggregator.isDefined
 
@@ -202,15 +203,6 @@ private[spark] class ExternalSorter[K, V, C](
         map.changeValue((getPartition(kv._1), kv._1), update)
         maybeSpillCollection(usingMap = true)
       }
-    } else if (bypassMergeSort) {
-      // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
-      if (records.hasNext) {
-        spillToPartitionFiles(
-          WritablePartitionedIterator.fromIterator(records.map { kv =>
-            ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
-          })
-        )
-      }
     } else {
       // Stick values into our buffer
       while (records.hasNext) {
@@ -238,46 +230,33 @@ private[spark] class ExternalSorter[K, V, C](
       }
     } else {
       if (maybeSpill(buffer, buffer.estimateSize())) {
-        buffer = if (useSerializedPairBuffer) {
-          new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
-        } else {
-          new PartitionedPairBuffer[K, C]
-        }
+        buffer = newBuffer()
       }
     }
   }
 
   /**
-   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
-   */
-  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    if (bypassMergeSort) {
-      spillToPartitionFiles(collection)
-    } else {
-      spillToMergeableFile(collection)
-    }
-  }
-
-  /**
-   * Spill our in-memory collection to a sorted file that we can merge later (normal code path).
-   * We add this file into spilledFiles to find it later.
-   *
-   * This should not be invoked if bypassMergeSort is true. In that case, spillToPartitionedFiles()
-   * is used to write files for each partition.
+   * Spill our in-memory collection to a sorted file that we can merge later.
+   * We add this file into `spilledFiles` to find it later.
    *
    * @param collection whichever collection we're using (map or buffer)
    */
-  private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    assert(!bypassMergeSort)
-
+  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
     // Because these files may be read during shuffle, their compression must be controlled by
     // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
     // createTempShuffleBlock here; see SPARK-3426 for more context.
     val (blockId, file) = diskBlockManager.createTempShuffleBlock()
-    curWriteMetrics = new ShuffleWriteMetrics()
-    var writer = blockManager.getDiskWriter(
-      blockId, file, serInstance, fileBufferSize, curWriteMetrics)
-    var objectsWritten = 0   // Objects written since the last flush
+
+    // These variables are reset after each flush
+    var objectsWritten: Long = 0
+    var spillMetrics: ShuffleWriteMetrics = null
+    var writer: BlockObjectWriter = null
+    def openWriter(): Unit = {
+      assert (writer == null && spillMetrics == null)
+      spillMetrics = new ShuffleWriteMetrics
+      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
+    }
+    openWriter()
 
     // List of batch sizes (bytes) in the order they are written to disk
     val batchSizes = new ArrayBuffer[Long]
@@ -291,8 +270,9 @@ private[spark] class ExternalSorter[K, V, C](
       val w = writer
       writer = null
       w.commitAndClose()
-      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
-      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
+      _diskBytesSpilled += spillMetrics.shuffleBytesWritten
+      batchSizes.append(spillMetrics.shuffleBytesWritten)
+      spillMetrics = null
       objectsWritten = 0
     }
 
@@ -307,9 +287,7 @@ private[spark] class ExternalSorter[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          curWriteMetrics = new ShuffleWriteMetrics()
-          writer = blockManager.getDiskWriter(
-            blockId, file, serInstance, fileBufferSize, curWriteMetrics)
+          openWriter()
         }
       }
       if (objectsWritten > 0) {
@@ -337,46 +315,6 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Spill our in-memory collection to separate files, one for each partition. This is used when
-   * there's no aggregator and ordering and the number of partitions is small, because it allows
-   * writePartitionedFile to just concatenate files without deserializing data.
-   *
-   * @param collection whichever collection we're using (map or buffer)
-   */
-  private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
-    spillToPartitionFiles(collection.writablePartitionedIterator())
-  }
-
-  private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
-    assert(bypassMergeSort)
-
-    // Create our file writers if we haven't done so yet
-    if (partitionWriters == null) {
-      curWriteMetrics = new ShuffleWriteMetrics()
-      val openStartTime = System.nanoTime
-      partitionWriters = Array.fill(numPartitions) {
-        // Because these files may be read during shuffle, their compression must be controlled by
-        // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
-        // createTempShuffleBlock here; see SPARK-3426 for more context.
-        val (blockId, file) = diskBlockManager.createTempShuffleBlock()
-        val writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize,
-          curWriteMetrics)
-        writer.open()
-      }
-      // Creating the file to write to and creating a disk writer both involve interacting with
-      // the disk, and can take a long time in aggregate when we open many files, so should be
-      // included in the shuffle write time.
-      curWriteMetrics.incShuffleWriteTime(System.nanoTime - openStartTime)
-    }
-
-    // No need to sort stuff, just write each element out
-    while (iterator.hasNext) {
-      val partitionId = iterator.nextPartition()
-      iterator.writeNext(partitionWriters(partitionId))
-    }
-  }
-
-  /**
    * Merge a sequence of sorted files, giving an iterator over partitions and then over elements
    * inside each partition. This can be used to either write out a new file or return data to
    * the user.
@@ -665,8 +603,6 @@ private[spark] class ExternalSorter[K, V, C](
   }
 
   /**
-   * Exposed for testing purposes.
-   *
    * Return an iterator over all the data written to this object, grouped by partition and
    * aggregated by the requested aggregator. For each partition we then have an iterator over its
    * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
@@ -676,10 +612,11 @@ private[spark] class ExternalSorter[K, V, C](
    * For now, we just merge all the spilled files in once pass, but this can be modified to
    * support hierarchical merging.
    */
-   def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+  @VisibleForTesting
+  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
     val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
-    if (spills.isEmpty && partitionWriters == null) {
+    if (spills.isEmpty) {
       // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
@@ -689,13 +626,6 @@ private[spark] class ExternalSorter[K, V, C](
         // We do need to sort by both partition ID and key
         groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
       }
-    } else if (bypassMergeSort) {
-      // Read data from each partition file and merge it together with the data in memory;
-      // note that there's no ordering or aggregator in this case -- we just partition objects
-      val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
-      collIter.map { case (partitionId, values) =>
-        (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
-      }
     } else {
       // Merge spilled and in-memory data
       merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
@@ -709,14 +639,13 @@ private[spark] class ExternalSorter[K, V, C](
 
   /**
    * Write all the data added into this ExternalSorter into a file in the disk store. This is
-   * called by the SortShuffleWriter and can go through an efficient path of just concatenating
-   * binary files if we decided to avoid merge-sorting.
+   * called by the SortShuffleWriter.
    *
    * @param blockId block ID to write to. The index file will be blockId.name + ".index".
    * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
    * @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
    */
-  def writePartitionedFile(
+  override def writePartitionedFile(
       blockId: BlockId,
       context: TaskContext,
       outputFile: File): Array[Long] = {
@@ -724,28 +653,7 @@ private[spark] class ExternalSorter[K, V, C](
     // Track location of each range in the output file
     val lengths = new Array[Long](numPartitions)
 
-    if (bypassMergeSort && partitionWriters != null) {
-      // We decided to write separate files for each partition, so just concatenate them. To keep
-      // this simple we spill out the current in-memory collection so that everything is in files.
-      spillToPartitionFiles(if (aggregator.isDefined) map else buffer)
-      partitionWriters.foreach(_.commitAndClose())
-      val out = new FileOutputStream(outputFile, true)
-      val writeStartTime = System.nanoTime
-      util.Utils.tryWithSafeFinally {
-        for (i <- 0 until numPartitions) {
-          val in = new FileInputStream(partitionWriters(i).fileSegment().file)
-          util.Utils.tryWithSafeFinally {
-            lengths(i) = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
-          } {
-            in.close()
-          }
-        }
-      } {
-        out.close()
-        context.taskMetrics.shuffleWriteMetrics.foreach(
-          _.incShuffleWriteTime(System.nanoTime - writeStartTime))
-      }
-    } else if (spills.isEmpty && partitionWriters == null) {
+    if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
       val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
@@ -761,7 +669,7 @@ private[spark] class ExternalSorter[K, V, C](
         lengths(partitionId) = segment.length
       }
     } else {
-      // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
+      // We must perform merge-sort; get an iterator by partition and write everything directly.
       for ((id, elements) <- this.partitionedIterator) {
         if (elements.hasNext) {
           val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
@@ -778,41 +686,15 @@ private[spark] class ExternalSorter[K, V, C](
 
     context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
-    context.taskMetrics.shuffleWriteMetrics.filter(_ => bypassMergeSort).foreach { m =>
-      if (curWriteMetrics != null) {
-        m.incShuffleBytesWritten(curWriteMetrics.shuffleBytesWritten)
-        m.incShuffleWriteTime(curWriteMetrics.shuffleWriteTime)
-        m.incShuffleRecordsWritten(curWriteMetrics.shuffleRecordsWritten)
-      }
-    }
 
     lengths
   }
 
-  /**
-   * Read a partition file back as an iterator (used in our iterator method)
-   */
-  private def readPartitionFile(writer: BlockObjectWriter): Iterator[Product2[K, C]] = {
-    if (writer.isOpen) {
-      writer.commitAndClose()
-    }
-    new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
-  }
-
   def stop(): Unit = {
     spills.foreach(s => s.file.delete())
     spills.clear()
-    if (partitionWriters != null) {
-      partitionWriters.foreach { w =>
-        w.revertPartialWritesAndClose()
-        diskBlockManager.getFile(w.blockId).delete()
-      }
-      partitionWriters = null
-    }
   }
 
-  def diskBytesSpilled: Long = _diskBytesSpilled
-
   /**
    * Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
    * group together the pairs for each partition into a sub-iterator.
@@ -826,14 +708,6 @@ private[spark] class ExternalSorter[K, V, C](
     (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
   }
 
-  private def comparator: Option[Comparator[K]] = {
-    if (ordering.isDefined || aggregator.isDefined) {
-      Some(keyComparator)
-    } else {
-      None
-    }
-  }
-
   /**
    * An iterator that reads only the elements for a given partition ID from an underlying buffered
    * stream, assuming this partition is the next one to be read. Used to make it easier to return

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
deleted file mode 100644
index d75959f..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
+++ /dev/null
@@ -1,24 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
-  def hasNext: Boolean = iter.hasNext
-
-  def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
index e2e2f1f..d0d25b4 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -34,10 +34,6 @@ private[spark] class PartitionedAppendOnlyMap[K, V]
     destructiveSortedIterator(comparator)
   }
 
-  def writablePartitionedIterator(): WritablePartitionedIterator = {
-    WritablePartitionedIterator.fromIterator(super.iterator)
-  }
-
   def insert(partition: Int, key: K, value: V): Unit = {
     update((partition, key), value)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
index e8332e1..5a6e9a9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -71,10 +71,6 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
     iterator
   }
 
-  override def writablePartitionedIterator(): WritablePartitionedIterator = {
-    WritablePartitionedIterator.fromIterator(iterator)
-  }
-
   private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
     var pos = 0
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
index 554d882..862408b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -122,10 +122,6 @@ private[spark] class PartitionedSerializedPairBuffer[K, V](
   override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
     : WritablePartitionedIterator = {
     sort(keyComparator)
-    writablePartitionedIterator
-  }
-
-  override def writablePartitionedIterator(): WritablePartitionedIterator = {
     new WritablePartitionedIterator {
       // current position in the meta buffer in ints
       var pos = 0

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
index f26d161..7bc5989 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -47,13 +47,20 @@ private[spark] trait WritablePartitionedPairCollection[K, V] {
    */
   def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
     : WritablePartitionedIterator = {
-    WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
-  }
+    val it = partitionedDestructiveSortedIterator(keyComparator)
+    new WritablePartitionedIterator {
+      private[this] var cur = if (it.hasNext) it.next() else null
 
-  /**
-   * Iterate through the data and write out the elements instead of returning them.
-   */
-  def writablePartitionedIterator(): WritablePartitionedIterator
+      def writeNext(writer: BlockObjectWriter): Unit = {
+        writer.write(cur._1._2, cur._2)
+        cur = if (it.hasNext) it.next() else null
+      }
+
+      def hasNext(): Boolean = cur != null
+
+      def nextPartition(): Int = cur._1._1
+    }
+  }
 }
 
 private[spark] object WritablePartitionedPairCollection {
@@ -94,20 +101,3 @@ private[spark] trait WritablePartitionedIterator {
 
   def nextPartition(): Int
 }
-
-private[spark] object WritablePartitionedIterator {
-  def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
-    new WritablePartitionedIterator {
-      var cur = if (it.hasNext) it.next() else null
-
-      def writeNext(writer: BlockObjectWriter): Unit = {
-        writer.write(cur._1._2, cur._2)
-        cur = if (it.hasNext) it.next() else null
-      }
-
-      def hasNext(): Boolean = cur != null
-
-      def nextPartition(): Int = cur._1._1
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 91f4ab3..c3c2b1f 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.Matchers
 
 import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
 import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
 import org.apache.spark.util.MutablePair
@@ -281,6 +282,39 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     // This count should retry the execution of the previous stage and rerun shuffle.
     rdd.count()
   }
+
+  test("metrics for shuffle without aggregation") {
+    sc = new SparkContext("local", "test", conf.clone())
+    val numRecords = 10000
+
+    val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+      sc.parallelize(1 to numRecords, 4)
+        .map(key => (key, 1))
+        .groupByKey()
+        .collect()
+    }
+
+    assert(metrics.recordsRead === numRecords)
+    assert(metrics.recordsWritten === numRecords)
+    assert(metrics.bytesWritten === metrics.bytesRead)
+    assert(metrics.bytesWritten > 0)
+  }
+
+  test("metrics for shuffle with aggregation") {
+    sc = new SparkContext("local", "test", conf.clone())
+    val numRecords = 10000
+
+    val metrics = ShuffleSuite.runAndReturnMetrics(sc) {
+      sc.parallelize(1 to numRecords, 4)
+        .flatMap(key => Array.fill(100)((key, 1)))
+        .countByKey()
+    }
+
+    assert(metrics.recordsRead === numRecords)
+    assert(metrics.recordsWritten === numRecords)
+    assert(metrics.bytesWritten === metrics.bytesRead)
+    assert(metrics.bytesWritten > 0)
+  }
 }
 
 object ShuffleSuite {
@@ -294,4 +328,35 @@ object ShuffleSuite {
       value - o.value
     }
   }
+
+  case class AggregatedShuffleMetrics(
+    recordsWritten: Long,
+    recordsRead: Long,
+    bytesWritten: Long,
+    bytesRead: Long)
+
+  def runAndReturnMetrics(sc: SparkContext)(job: => Unit): AggregatedShuffleMetrics = {
+    @volatile var recordsWritten: Long = 0
+    @volatile var recordsRead: Long = 0
+    @volatile var bytesWritten: Long = 0
+    @volatile var bytesRead: Long = 0
+    val listener = new SparkListener {
+      override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+        taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m =>
+          recordsWritten += m.shuffleRecordsWritten
+          bytesWritten += m.shuffleBytesWritten
+        }
+        taskEnd.taskMetrics.shuffleReadMetrics.foreach { m =>
+          recordsRead += m.recordsRead
+          bytesRead += m.totalBytesRead
+        }
+      }
+    }
+    sc.addSparkListener(listener)
+
+    job
+
+    sc.listenerBus.waitUntilEmpty(500)
+    AggregatedShuffleMetrics(recordsWritten, recordsRead, bytesWritten, bytesRead)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 19f1af0..9e4d34f 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -193,26 +193,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     assert(records == numRecords)
   }
 
-  test("shuffle records read metrics") {
-    val recordsRead = runAndReturnShuffleRecordsRead {
-      sc.textFile(tmpFilePath, 4)
-        .map(key => (key, 1))
-        .groupByKey()
-        .collect()
-    }
-    assert(recordsRead == numRecords)
-  }
-
-  test("shuffle records written metrics") {
-    val recordsWritten = runAndReturnShuffleRecordsWritten {
-      sc.textFile(tmpFilePath, 4)
-        .map(key => (key, 1))
-        .groupByKey()
-        .collect()
-    }
-    assert(recordsWritten == numRecords)
-  }
-
   /**
    * Tests the metrics from end to end.
    * 1) reading a hadoop file
@@ -301,14 +281,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten))
   }
 
-  private def runAndReturnShuffleRecordsRead(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.shuffleReadMetrics.map(_.recordsRead))
-  }
-
-  private def runAndReturnShuffleRecordsWritten(job: => Unit): Long = {
-    runAndReturnMetrics(job, _.taskMetrics.shuffleWriteMetrics.map(_.shuffleRecordsWritten))
-  }
-
   private def runAndReturnMetrics(job: => Unit,
       collector: (SparkListenerTaskEnd) => Option[Long]): Long = {
     val taskMetrics = new ArrayBuffer[Long]()

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
new file mode 100644
index 0000000..c8420db
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.File
+import java.util.UUID
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.scalatest.{BeforeAndAfterEach, FunSuite}
+
+import org.apache.spark._
+import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.serializer.{SerializerInstance, Serializer, JavaSerializer}
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+
+class BypassMergeSortShuffleWriterSuite extends FunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var taskContext: TaskContext = _
+
+  private var taskMetrics: TaskMetrics = _
+  private var shuffleWriteMetrics: ShuffleWriteMetrics = _
+  private var tempDir: File = _
+  private var outputFile: File = _
+  private val conf: SparkConf = new SparkConf(loadDefaults = false)
+  private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]()
+  private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File]
+  private val shuffleBlockId: ShuffleBlockId = new ShuffleBlockId(0, 0, 0)
+  private val serializer: Serializer = new JavaSerializer(conf)
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+    outputFile = File.createTempFile("shuffle", null, tempDir)
+    shuffleWriteMetrics = new ShuffleWriteMetrics
+    taskMetrics = new TaskMetrics
+    taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
+    MockitoAnnotations.initMocks(this)
+    when(taskContext.taskMetrics()).thenReturn(taskMetrics)
+    when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+    when(blockManager.getDiskWriter(
+      any[BlockId],
+      any[File],
+      any[SerializerInstance],
+      anyInt(),
+      any[ShuffleWriteMetrics]
+    )).thenAnswer(new Answer[BlockObjectWriter] {
+      override def answer(invocation: InvocationOnMock): BlockObjectWriter = {
+        val args = invocation.getArguments
+        new DiskBlockObjectWriter(
+          args(0).asInstanceOf[BlockId],
+          args(1).asInstanceOf[File],
+          args(2).asInstanceOf[SerializerInstance],
+          args(3).asInstanceOf[Int],
+          compressStream = identity,
+          syncWrites = false,
+          args(4).asInstanceOf[ShuffleWriteMetrics]
+        )
+      }
+    })
+    when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
+      new Answer[(TempShuffleBlockId, File)] {
+        override def answer(invocation: InvocationOnMock): (TempShuffleBlockId, File) = {
+          val blockId = new TempShuffleBlockId(UUID.randomUUID)
+          val file = File.createTempFile(blockId.toString, null, tempDir)
+          blockIdToFileMap.put(blockId, file)
+          temporaryFilesCreated.append(file)
+          (blockId, file)
+        }
+      })
+    when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+      new Answer[File] {
+        override def answer(invocation: InvocationOnMock): File = {
+          blockIdToFileMap.get(invocation.getArguments.head.asInstanceOf[BlockId]).get
+        }
+    })
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+    blockIdToFileMap.clear()
+    temporaryFilesCreated.clear()
+  }
+
+  test("write empty iterator") {
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      new SparkConf(loadDefaults = false),
+      blockManager,
+      new HashPartitioner(7),
+      shuffleWriteMetrics,
+      serializer
+    )
+    writer.insertAll(Iterator.empty)
+    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+    assert(partitionLengths.sum === 0)
+    assert(outputFile.exists())
+    assert(outputFile.length() === 0)
+    assert(temporaryFilesCreated.isEmpty)
+    assert(shuffleWriteMetrics.shuffleBytesWritten === 0)
+    assert(shuffleWriteMetrics.shuffleRecordsWritten === 0)
+    assert(taskMetrics.diskBytesSpilled === 0)
+    assert(taskMetrics.memoryBytesSpilled === 0)
+  }
+
+  test("write with some empty partitions") {
+    def records: Iterator[(Int, Int)] =
+      Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      new SparkConf(loadDefaults = false),
+      blockManager,
+      new HashPartitioner(7),
+      shuffleWriteMetrics,
+      serializer
+    )
+    writer.insertAll(records)
+    assert(temporaryFilesCreated.nonEmpty)
+    val partitionLengths = writer.writePartitionedFile(shuffleBlockId, taskContext, outputFile)
+    assert(partitionLengths.sum === outputFile.length())
+    assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted
+    assert(shuffleWriteMetrics.shuffleBytesWritten === outputFile.length())
+    assert(shuffleWriteMetrics.shuffleRecordsWritten === records.length)
+    assert(taskMetrics.diskBytesSpilled === 0)
+    assert(taskMetrics.memoryBytesSpilled === 0)
+  }
+
+  test("cleanup of intermediate files after errors") {
+    val writer = new BypassMergeSortShuffleWriter[Int, Int](
+      new SparkConf(loadDefaults = false),
+      blockManager,
+      new HashPartitioner(7),
+      shuffleWriteMetrics,
+      serializer
+    )
+    intercept[SparkException] {
+      writer.insertAll((0 until 100000).iterator.map(i => {
+        if (i == 99990) {
+          throw new SparkException("Intentional failure")
+        }
+        (i, i)
+      }))
+    }
+    assert(temporaryFilesCreated.nonEmpty)
+    writer.stop()
+    assert(temporaryFilesCreated.count(_.exists()) === 0)
+  }
+
+}
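
The suite above checks that writePartitionedFile() returns per-partition lengths that sum to the final output file's length and that the per-partition temporary files are deleted afterwards. A minimal sketch of the concatenation step implied by those checks, with hypothetical names (this is not the BypassMergeSortShuffleWriter implementation):

// Illustrative sketch of per-partition file concatenation; names are hypothetical.
import java.io.{File, FileInputStream, FileOutputStream}

object ConcatenationSketch {
  /** Appends each per-partition temp file to `output` and returns the per-partition lengths. */
  def concatenatePartitionFiles(partitionFiles: Seq[File], output: File): Array[Long] = {
    val out = new FileOutputStream(output, true)
    try {
      partitionFiles.map { file =>
        val in = new FileInputStream(file)
        try {
          var copied = 0L
          val buf = new Array[Byte](8192)
          var n = in.read(buf)
          while (n != -1) {
            out.write(buf, 0, n)
            copied += n
            n = in.read(buf)
          }
          copied  // this partition's length in the final map output file
        } finally {
          in.close()
        }
      }.toArray
    } finally {
      out.close()
    }
  }
}

Under this structure, partitionLengths.sum === outputFile.length() follows directly, which is exactly what the "write with some empty partitions" test asserts.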

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
new file mode 100644
index 0000000..c6ada71
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleWriterSuite.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import org.mockito.Mockito._
+import org.scalatest.FunSuite
+
+import org.apache.spark.{Aggregator, SparkConf}
+
+class SortShuffleWriterSuite extends FunSuite {
+
+  import SortShuffleWriter._
+
+  test("conditions for bypassing merge-sort") {
+    val conf = new SparkConf(loadDefaults = false)
+    val agg = mock(classOf[Aggregator[_, _, _]], RETURNS_SMART_NULLS)
+    val ord = implicitly[Ordering[Int]]
+
+    // Numbers of partitions that are above and below the default bypassMergeThreshold
+    val FEW_PARTITIONS = 50
+    val MANY_PARTITIONS = 10000
+
+    // Shuffles with no ordering or aggregator: should bypass unless # of partitions is high
+    assert(shouldBypassMergeSort(conf, FEW_PARTITIONS, None, None))
+    assert(!shouldBypassMergeSort(conf, MANY_PARTITIONS, None, None))
+
+    // Shuffles with an ordering or aggregator: should not bypass even if they have few partitions
+    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, None, Some(ord)))
+    assert(!shouldBypassMergeSort(conf, FEW_PARTITIONS, Some(agg), None))
+  }
+}
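
The test above asserts that a shuffle bypasses merge-sort only when it has no aggregator, no key ordering, and a small number of partitions. As a reading aid, a minimal sketch of such a predicate, assuming the default of 200 for spark.shuffle.sort.bypassMergeThreshold (the real logic lives in SortShuffleWriter and may differ in detail):

// Minimal sketch of the bypass predicate exercised above; not the actual
// SortShuffleWriter.shouldBypassMergeSort implementation.
import org.apache.spark.SparkConf

object BypassPredicateSketch {
  def shouldBypass(
      conf: SparkConf,
      numPartitions: Int,
      aggregator: Option[_],
      ordering: Option[_]): Boolean = {
    // Assumes the default bypass threshold of 200 partitions.
    val threshold = conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
    // Bypass only when there is no map-side aggregation, no key ordering, and the
    // number of output partitions does not exceed the threshold.
    aggregator.isEmpty && ordering.isEmpty && numPartitions <= threshold
  }
}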

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index ad43a3e..7bdea72 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -18,14 +18,27 @@ package org.apache.spark.storage
 
 import java.io.File
 
+import org.scalatest.BeforeAndAfterEach
+
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.Utils
 
-class BlockObjectWriterSuite extends SparkFunSuite {
+class BlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  var tempDir: File = _
+
+  override def beforeEach(): Unit = {
+    tempDir = Utils.createTempDir()
+  }
+
+  override def afterEach(): Unit = {
+    Utils.deleteRecursively(tempDir)
+  }
+
   test("verify write metrics") {
-    val file = new File(Utils.createTempDir(), "somefile")
+    val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -47,7 +61,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
   }
 
   test("verify write metrics on revert") {
-    val file = new File(Utils.createTempDir(), "somefile")
+    val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -70,7 +84,7 @@ class BlockObjectWriterSuite extends SparkFunSuite {
   }
 
   test("Reopening a closed block writer") {
-    val file = new File(Utils.createTempDir(), "somefile")
+    val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
@@ -81,4 +95,79 @@ class BlockObjectWriterSuite extends SparkFunSuite {
       writer.open()
     }
   }
+
+  test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.commitAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+  }
+
+  test("commitAndClose() should be idempotent") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.commitAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    val writeTime = writeMetrics.shuffleWriteTime
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    writer.commitAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 1000)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+    assert(writeMetrics.shuffleWriteTime === writeTime)
+  }
+
+  test("revertPartialWritesAndClose() should be idempotent") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    writer.revertPartialWritesAndClose()
+    val bytesWritten = writeMetrics.shuffleBytesWritten
+    val writeTime = writeMetrics.shuffleWriteTime
+    assert(writeMetrics.shuffleRecordsWritten === 0)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleRecordsWritten === 0)
+    assert(writeMetrics.shuffleBytesWritten === bytesWritten)
+    assert(writeMetrics.shuffleWriteTime === writeTime)
+  }
+
+  test("fileSegment() can only be called after commitAndClose() has been called") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    for (i <- 1 to 1000) {
+      writer.write(i, i)
+    }
+    intercept[IllegalStateException] {
+      writer.fileSegment()
+    }
+    writer.close()
+  }
+
+  test("commitAndClose() without ever opening or writing") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+    writer.commitAndClose()
+    assert(writer.fileSegment().length === 0)
+  }
 }
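
The new idempotency tests above rely on the writer tracking whether it has already been committed or reverted, so that repeated close calls leave the metrics untouched and fileSegment() is rejected before a commit. A minimal guard-pattern sketch of that behavior, with hypothetical names and far less state than the real DiskBlockObjectWriter:

// Illustrative guard pattern only; the real writer also manages streams, file
// positions, and byte-level write metrics.
class CloseOnceSketch {
  private var committed = false
  private var reverted = false
  private var recordsWritten = 0L

  def write(): Unit = {
    require(!committed && !reverted, "writer already closed")
    recordsWritten += 1
  }

  // Safe to call repeatedly: state is finalized exactly once.
  def commitAndClose(): Unit = {
    if (!committed && !reverted) {
      committed = true
    }
  }

  // No effect once the writer has been committed or already reverted.
  def revertPartialWritesAndClose(): Unit = {
    if (!committed && !reverted) {
      recordsWritten = 0L
      reverted = true
    }
  }

  // Only valid after a successful commit, mirroring the new fileSegment() check.
  def committedRecords: Long = {
    if (!committed) {
      throw new IllegalStateException("accessed before commitAndClose()")
    }
    recordsWritten
  }
}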

http://git-wip-us.apache.org/repos/asf/spark/blob/a6430028/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index 9039dbe..7d7b41b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -23,10 +23,12 @@ import org.scalatest.PrivateMethodTester
 
 import scala.util.Random
 
+import org.scalatest.FunSuite
+
 import org.apache.spark._
 import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
 
-class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester {
+class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
   private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
     val conf = new SparkConf(loadDefaults)
     if (kryo) {
@@ -37,21 +39,12 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
       conf.set("spark.serializer.objectStreamReset", "1")
       conf.set("spark.serializer", classOf[JavaSerializer].getName)
     }
+    conf.set("spark.shuffle.sort.bypassMergeThreshold", "0")
     // Ensure that we actually have multiple batches per spill file
     conf.set("spark.shuffle.spill.batchSize", "10")
     conf
   }
 
-  private def assertBypassedMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
-    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
-    assert(sorter.invokePrivate(bypassMergeSort()), "sorter did not bypass merge-sort")
-  }
-
-  private def assertDidNotBypassMergeSort(sorter: ExternalSorter[_, _, _]): Unit = {
-    val bypassMergeSort = PrivateMethod[Boolean]('bypassMergeSort)
-    assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
-  }
-
   test("empty data stream with kryo ser") {
     emptyDataStream(createSparkConf(false, true))
   }
@@ -161,39 +154,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(7)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter)
-    sorter.insertAll(elements)
-    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
-    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
-    assert(iter.next() === (0, Nil))
-    assert(iter.next() === (1, List((1, 1))))
-    assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
-    assert(iter.next() === (3, Nil))
-    assert(iter.next() === (4, Nil))
-    assert(iter.next() === (5, List((5, 5))))
-    assert(iter.next() === (6, Nil))
-    sorter.stop()
-  }
-
-  test("empty partitions with spilling, bypass merge-sort with kryo ser") {
-    emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, true))
-  }
-
-  test("empty partitions with spilling, bypass merge-sort with java ser") {
-    emptyPartitionerWithSpillingBypassMergeSort(createSparkConf(false, false))
-  }
-
-  def emptyPartitionerWithSpillingBypassMergeSort(conf: SparkConf) {
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
-
-    val sorter = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(7)), None, None)
-    assertBypassedMergeSort(sorter)
     sorter.insertAll(elements)
     assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
     val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -376,7 +336,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter)
     sorter.insertAll((0 until 120000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     sorter.stop()
@@ -384,7 +343,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter2 = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter2)
     sorter2.insertAll((0 until 120000).iterator.map(i => (i, i)))
     assert(diskBlockManager.getAllFiles().length > 0)
     assert(sorter2.iterator.toSet === (0 until 120000).map(i => (i, i)).toSet)
@@ -392,29 +350,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
     assert(diskBlockManager.getAllBlocks().length === 0)
   }
 
-  test("cleanup of intermediate files in sorter, bypass merge-sort") {
-    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    assertBypassedMergeSort(sorter)
-    sorter.insertAll((0 until 100000).iterator.map(i => (i, i)))
-    assert(diskBlockManager.getAllFiles().length > 0)
-    sorter.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-
-    val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    assertBypassedMergeSort(sorter2)
-    sorter2.insertAll((0 until 100000).iterator.map(i => (i, i)))
-    assert(diskBlockManager.getAllFiles().length > 0)
-    assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
-    sorter2.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-  }
-
   test("cleanup of intermediate files in sorter if there are errors") {
     val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -426,7 +361,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
 
     val sorter = new ExternalSorter[Int, Int, Int](
       None, Some(new HashPartitioner(3)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter)
     intercept[SparkException] {
       sorter.insertAll((0 until 120000).iterator.map(i => {
         if (i == 119990) {
@@ -440,28 +374,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
     assert(diskBlockManager.getAllBlocks().length === 0)
   }
 
-  test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
-    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
-
-    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
-    assertBypassedMergeSort(sorter)
-    intercept[SparkException] {
-      sorter.insertAll((0 until 100000).iterator.map(i => {
-        if (i == 99990) {
-          throw new SparkException("Intentional failure")
-        }
-        (i, i)
-      }))
-    }
-    assert(diskBlockManager.getAllFiles().length > 0)
-    sorter.stop()
-    assert(diskBlockManager.getAllBlocks().length === 0)
-  }
-
   test("cleanup of intermediate files in shuffle") {
     val conf = createSparkConf(false, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
@@ -776,40 +688,6 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext with Priv
     }
   }
 
-  test("conditions for bypassing merge-sort") {
-    val conf = createSparkConf(false, false)
-    conf.set("spark.shuffle.memoryFraction", "0.001")
-    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
-    sc = new SparkContext("local", "test", conf)
-
-    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
-    val ord = implicitly[Ordering[Int]]
-
-    // Numbers of partitions that are above and below the default bypassMergeThreshold
-    val FEW_PARTITIONS = 50
-    val MANY_PARTITIONS = 10000
-
-    // Sorters with no ordering or aggregator: should bypass unless # of partitions is high
-
-    val sorter1 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
-    assertBypassedMergeSort(sorter1)
-
-    val sorter2 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(MANY_PARTITIONS)), None, None)
-    assertDidNotBypassMergeSort(sorter2)
-
-    // Sorters with an ordering or aggregator: should not bypass even if they have few partitions
-
-    val sorter3 = new ExternalSorter[Int, Int, Int](
-      None, Some(new HashPartitioner(FEW_PARTITIONS)), Some(ord), None)
-    assertDidNotBypassMergeSort(sorter3)
-
-    val sorter4 = new ExternalSorter[Int, Int, Int](
-      Some(agg), Some(new HashPartitioner(FEW_PARTITIONS)), None, None)
-    assertDidNotBypassMergeSort(sorter4)
-  }
-
   test("sort without breaking sorting contracts with kryo ser") {
     sortWithoutBreakingSortingContracts(createSparkConf(true, true))
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org