You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by rx...@apache.org on 2014/07/31 03:08:06 UTC

[1/2] SPARK-2045 Sort-based shuffle

Repository: spark
Updated Branches:
  refs/heads/master da5017668 -> e96628440


http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index ad20f9b..4bc4346 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -19,9 +19,6 @@ package org.apache.spark
 
 import java.lang.ref.WeakReference
 
-import org.apache.spark.broadcast.Broadcast
-
-import scala.collection.mutable
 import scala.collection.mutable.{HashSet, SynchronizedSet}
 import scala.language.existentials
 import scala.language.postfixOps
@@ -34,15 +31,28 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.{BlockId, BroadcastBlockId, RDDBlockId, ShuffleBlockId}
-
-class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
-
+import org.apache.spark.storage._
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.storage.BroadcastBlockId
+import org.apache.spark.storage.RDDBlockId
+import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.ShuffleIndexBlockId
+
+/**
+ * An abstract base class for context cleaner tests, which sets up a context with a config
+ * suitable for cleaner tests and provides some utility functions. Subclasses can use different
+ * config options, in particular, a different shuffle manager class
+ */
+abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager])
+  extends FunSuite with BeforeAndAfter with LocalSparkContext
+{
   implicit val defaultTimeout = timeout(10000 millis)
   val conf = new SparkConf()
     .setMaster("local[2]")
     .setAppName("ContextCleanerSuite")
     .set("spark.cleaner.referenceTracking.blocking", "true")
+    .set("spark.shuffle.manager", shuffleManager.getName)
 
   before {
     sc = new SparkContext(conf)
@@ -55,6 +65,59 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     }
   }
 
+  //------ Helper functions ------
+
+  protected def newRDD() = sc.makeRDD(1 to 10)
+  protected def newPairRDD() = newRDD().map(_ -> 1)
+  protected def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+  protected def newBroadcast() = sc.broadcast(1 to 100)
+
+  protected def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
+      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
+        getAllDependencies(dep.rdd)
+      }
+    }
+    val rdd = newShuffleRDD()
+
+    // Get all the shuffle dependencies
+    val shuffleDeps = getAllDependencies(rdd)
+      .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+      .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
+    (rdd, shuffleDeps)
+  }
+
+  protected def randomRdd() = {
+    val rdd: RDD[_] = Random.nextInt(3) match {
+      case 0 => newRDD()
+      case 1 => newShuffleRDD()
+      case 2 => newPairRDD.join(newPairRDD())
+    }
+    if (Random.nextBoolean()) rdd.persist()
+    rdd.count()
+    rdd
+  }
+
+  /** Run GC and make sure it actually has run */
+  protected def runGC() {
+    val weakRef = new WeakReference(new Object())
+    val startTime = System.currentTimeMillis
+    System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
+    // Wait until a weak reference object has been GCed
+    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+      System.gc()
+      Thread.sleep(200)
+    }
+  }
+
+  protected def cleaner = sc.cleaner.get
+}
+
+
+/**
+ * Basic ContextCleanerSuite, which uses hash-based shuffle
+ */
+class ContextCleanerSuite extends ContextCleanerSuiteBase {
   test("cleanup RDD") {
     val rdd = newRDD().persist()
     val collected = rdd.collect().toList
@@ -147,7 +210,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly
     val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = broadcastBuffer.map(_.id)
@@ -180,12 +243,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
       .setMaster("local-cluster[2, 1, 512]")
       .setAppName("ContextCleanerSuite")
       .set("spark.cleaner.referenceTracking.blocking", "true")
+      .set("spark.shuffle.manager", shuffleManager.getName)
     sc = new SparkContext(conf2)
 
     val numRdds = 10
     val numBroadcasts = 4 // Broadcasts are more costly
     val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast()).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = broadcastBuffer.map(_.id)
@@ -210,57 +274,82 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
       case _ => false
     }, askSlaves = true).isEmpty)
   }
+}
 
-  //------ Helper functions ------
 
-  private def newRDD() = sc.makeRDD(1 to 10)
-  private def newPairRDD() = newRDD().map(_ -> 1)
-  private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
-  private def newBroadcast() = sc.broadcast(1 to 100)
+/**
+ * A copy of the shuffle tests for sort-based shuffle
+ */
+class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) {
+  test("cleanup shuffle") {
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+    val collected = rdd.collect().toList
+    val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
 
-  private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
-    def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
-      rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
-        getAllDependencies(dep.rdd)
-      }
-    }
-    val rdd = newShuffleRDD()
+    // Explicit cleanup
+    shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true))
+    tester.assertCleanup()
 
-    // Get all the shuffle dependencies
-    val shuffleDeps = getAllDependencies(rdd)
-      .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
-      .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
-    (rdd, shuffleDeps)
+    // Verify that shuffles can be re-executed after cleaning up
+    assert(rdd.collect().toList.equals(collected))
   }
 
-  private def randomRdd() = {
-    val rdd: RDD[_] = Random.nextInt(3) match {
-      case 0 => newRDD()
-      case 1 => newShuffleRDD()
-      case 2 => newPairRDD.join(newPairRDD())
-    }
-    if (Random.nextBoolean()) rdd.persist()
+  test("automatically cleanup shuffle") {
+    var rdd = newShuffleRDD()
     rdd.count()
-    rdd
-  }
 
-  private def randomBroadcast() = {
-    sc.broadcast(Random.nextInt(Int.MaxValue))
+    // Test that GC does not cause shuffle cleanup due to a strong reference
+    val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
+    }
+
+    // Test that GC causes shuffle cleanup after dereferencing the RDD
+    val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0))
+    rdd = null  // Make RDD out of scope, so that corresponding shuffle goes out of scope
+    runGC()
+    postGCTester.assertCleanup()
   }
 
-  /** Run GC and make sure it actually has run */
-  private def runGC() {
-    val weakRef = new WeakReference(new Object())
-    val startTime = System.currentTimeMillis
-    System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
-    // Wait until a weak reference object has been GCed
-    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
-      System.gc()
-      Thread.sleep(200)
+  test("automatically cleanup RDD + shuffle + broadcast in distributed mode") {
+    sc.stop()
+
+    val conf2 = new SparkConf()
+      .setMaster("local-cluster[2, 1, 512]")
+      .setAppName("ContextCleanerSuite")
+      .set("spark.cleaner.referenceTracking.blocking", "true")
+      .set("spark.shuffle.manager", shuffleManager.getName)
+    sc = new SparkContext(conf2)
+
+    val numRdds = 10
+    val numBroadcasts = 4 // Broadcasts are more costly
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast).toBuffer
+    val rddIds = sc.persistentRdds.keys.toSeq
+    val shuffleIds = 0 until sc.newShuffleId()
+    val broadcastIds = broadcastBuffer.map(_.id)
+
+    val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    runGC()
+    intercept[Exception] {
+      preGCTester.assertCleanup()(timeout(1000 millis))
     }
-  }
 
-  private def cleaner = sc.cleaner.get
+    // Test that GC triggers the cleanup of all variables after dereferencing them
+    val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds)
+    broadcastBuffer.clear()
+    rddBuffer.clear()
+    runGC()
+    postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC.
+    val taskClosureBroadcastId = broadcastIds.max + 1
+    assert(sc.env.blockManager.master.getMatchingBlockIds({
+      case BroadcastBlockId(`taskClosureBroadcastId`, _) => true
+      case _ => false
+    }, askSlaves = true).isEmpty)
+  }
 }
 
 
@@ -418,6 +507,7 @@ class CleanerTester(
   private def getShuffleBlocks(shuffleId: Int): Seq[BlockId] = {
     blockManager.master.getMatchingBlockIds( _ match {
       case ShuffleBlockId(`shuffleId`, _, _) => true
+      case ShuffleIndexBlockId(`shuffleId`, _, _) => true
       case _ => false
     }, askSlaves = true)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index 47df000..d7b2d2e 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -28,6 +28,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
   }
 
   override def afterAll() {
-    System.setProperty("spark.shuffle.use.netty", "false")
+    System.clearProperty("spark.shuffle.use.netty")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index eae67c7..b13ddf9 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -58,8 +58,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // default Java serializer cannot handle the non serializable class.
     val c = new ShuffledRDD[Int,
       NonJavaSerializableClass,
-      NonJavaSerializableClass,
-      (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS))
+      NonJavaSerializableClass](b, new HashPartitioner(NUM_BLOCKS))
     c.setSerializer(new KryoSerializer(conf))
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
 
@@ -83,8 +82,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     // default Java serializer cannot handle the non serializable class.
     val c = new ShuffledRDD[Int,
       NonJavaSerializableClass,
-      NonJavaSerializableClass,
-      (Int, NonJavaSerializableClass)](b, new HashPartitioner(3))
+      NonJavaSerializableClass](b, new HashPartitioner(3))
     c.setSerializer(new KryoSerializer(conf))
     assert(c.count === 10)
   }
@@ -100,7 +98,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
 
     // NOTE: The default Java serializer doesn't create zero-sized blocks.
     //       So, use Kryo
-    val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
+    val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
       .setSerializer(new KryoSerializer(conf))
 
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -126,7 +124,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     val b = a.map(x => (x, x*2))
 
     // NOTE: The default Java serializer should create zero-sized blocks
-    val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
+    val c = new ShuffledRDD[Int, Int, Int](b, new HashPartitioner(10))
 
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
     assert(c.count === 4)
@@ -141,19 +139,19 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     assert(nonEmptyBlocks.size <= 4)
   }
 
-  test("shuffle using mutable pairs") {
+  test("shuffle on mutable pairs") {
     // Use a local cluster with 2 processes to make sure there are both local and remote blocks
     sc = new SparkContext("local-cluster[2,1,512]", "test")
     def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
     val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
     val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
-    val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs,
+    val results = new ShuffledRDD[Int, Int, Int](pairs,
       new HashPartitioner(2)).collect()
 
-    data.foreach { pair => results should contain (pair) }
+    data.foreach { pair => results should contain ((pair._1, pair._2)) }
   }
 
-  test("sorting using mutable pairs") {
+  test("sorting on mutable pairs") {
     // This is not in SortingSuite because of the local cluster setup.
     // Use a local cluster with 2 processes to make sure there are both local and remote blocks
     sc = new SparkContext("local-cluster[2,1,512]", "test")
@@ -162,10 +160,10 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
     val results = new OrderedRDDFunctions[Int, Int, MutablePair[Int, Int]](pairs)
       .sortByKey().collect()
-    results(0) should be (p(1, 11))
-    results(1) should be (p(2, 22))
-    results(2) should be (p(3, 33))
-    results(3) should be (p(100, 100))
+    results(0) should be ((1, 11))
+    results(1) should be ((2, 22))
+    results(2) should be ((3, 33))
+    results(3) should be ((100, 100))
   }
 
   test("cogroup using mutable pairs") {

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
new file mode 100644
index 0000000..5c02c00
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.BeforeAndAfterAll
+
+class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
+
+  // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
+
+  override def beforeAll() {
+    System.setProperty("spark.shuffle.manager",
+      "org.apache.spark.shuffle.sort.SortShuffleManager")
+  }
+
+  override def afterAll() {
+    System.clearProperty("spark.shuffle.manager")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 4953d56..8966eed 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -270,7 +270,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     // we can optionally shuffle to keep the upstream parallel
     val coalesced5 = data.coalesce(1, shuffle = true)
     val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd.
-      asInstanceOf[ShuffledRDD[_, _, _, _]] != null
+      asInstanceOf[ShuffledRDD[_, _, _]] != null
     assert(isEquals)
 
     // when shuffling, we can increase the number of partitions
@@ -730,9 +730,9 @@ class RDDSuite extends FunSuite with SharedSparkContext {
 
     // Any ancestors before the shuffle are not considered
     assert(ancestors4.size === 0)
-    assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0)
+    assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 0)
     assert(ancestors5.size === 3)
-    assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1)
+    assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
     assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0)
     assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 0b7ad18..7de5df6 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -208,11 +208,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
     val resultA = rddA.reduceByKey(math.max).collect()
     assert(resultA.length == 50000)
     resultA.foreach { case(k, v) =>
-      k match {
-        case 0 => assert(v == 1)
-        case 25000 => assert(v == 50001)
-        case 49999 => assert(v == 99999)
-        case _ =>
+      if (v != k * 2 + 1) {
+        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
       }
     }
 
@@ -221,11 +218,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
     val resultB = rddB.groupByKey().collect()
     assert(resultB.length == 25000)
     resultB.foreach { case(i, seq) =>
-      i match {
-        case 0 => assert(seq.toSet == Set[Int](0, 1, 2, 3))
-        case 12500 => assert(seq.toSet == Set[Int](50000, 50001, 50002, 50003))
-        case 24999 => assert(seq.toSet == Set[Int](99996, 99997, 99998, 99999))
-        case _ =>
+      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+      if (seq.toSet != expected) {
+        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
       }
     }
 
@@ -239,6 +234,9 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
         case 0 =>
           assert(seq1.toSet == Set[Int](0))
           assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+        case 1 =>
+          assert(seq1.toSet == Set[Int](1))
+          assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
         case 5000 =>
           assert(seq1.toSet == Set[Int](5000))
           assert(seq2.toSet == Set[Int]())
@@ -369,10 +367,3 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext {
   }
 
 }
-
-/**
- * A dummy class that always returns the same hash code, to easily test hash collisions
- */
-case class FixedHashObject(v: Int, h: Int) extends Serializable {
-  override def hashCode(): Int = h
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
new file mode 100644
index 0000000..ddb5df4
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -0,0 +1,566 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark._
+import org.apache.spark.SparkContext._
+
+class ExternalSorterSuite extends FunSuite with LocalSparkContext {
+  test("empty data stream") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+
+    // Both aggregator and ordering
+    val sorter = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+    assert(sorter.iterator.toSeq === Seq())
+    sorter.stop()
+
+    // Only aggregator
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(3)), None, None)
+    assert(sorter2.iterator.toSeq === Seq())
+    sorter2.stop()
+
+    // Only ordering
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    assert(sorter3.iterator.toSeq === Seq())
+    sorter3.stop()
+
+    // Neither aggregator nor ordering
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), None, None)
+    assert(sorter4.iterator.toSeq === Seq())
+    sorter4.stop()
+  }
+
+  test("few elements per partition") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val elements = Set((1, 1), (2, 2), (5, 5))
+    val expected = Set(
+      (0, Set()), (1, Set((1, 1))), (2, Set((2, 2))), (3, Set()), (4, Set()),
+      (5, Set((5, 5))), (6, Set()))
+
+    // Both aggregator and ordering
+    val sorter = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+    sorter.write(elements.iterator)
+    assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter.stop()
+
+    // Only aggregator
+    val sorter2 = new ExternalSorter[Int, Int, Int](
+      Some(agg), Some(new HashPartitioner(7)), None, None)
+    sorter2.write(elements.iterator)
+    assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter2.stop()
+
+    // Only ordering
+    val sorter3 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), Some(ord), None)
+    sorter3.write(elements.iterator)
+    assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter3.stop()
+
+    // Neither aggregator nor ordering
+    val sorter4 = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), None, None)
+    sorter4.write(elements.iterator)
+    assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
+    sorter4.stop()
+  }
+
+  test("empty partitions with spilling") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val elements = Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2))
+
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(7)), None, None)
+    sorter.write(elements)
+    assert(sc.env.blockManager.diskBlockManager.getAllFiles().length > 0) // Make sure it spilled
+    val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
+    assert(iter.next() === (0, Nil))
+    assert(iter.next() === (1, List((1, 1))))
+    assert(iter.next() === (2, (0 until 100000).map(x => (2, 2)).toList))
+    assert(iter.next() === (3, Nil))
+    assert(iter.next() === (4, Nil))
+    assert(iter.next() === (5, List((5, 5))))
+    assert(iter.next() === (6, Nil))
+    sorter.stop()
+  }
+
+  test("spilling in local cluster") {
+    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    // reduceByKey - should spill ~8 times
+    val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+    val resultA = rddA.reduceByKey(math.max).collect()
+    assert(resultA.length == 50000)
+    resultA.foreach { case(k, v) =>
+      if (v != k * 2 + 1) {
+        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
+      }
+    }
+
+    // groupByKey - should spill ~17 times
+    val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+    val resultB = rddB.groupByKey().collect()
+    assert(resultB.length == 25000)
+    resultB.foreach { case(i, seq) =>
+      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+      if (seq.toSet != expected) {
+        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
+      }
+    }
+
+    // cogroup - should spill ~7 times
+    val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+    val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+    val resultC = rddC1.cogroup(rddC2).collect()
+    assert(resultC.length == 10000)
+    resultC.foreach { case(i, (seq1, seq2)) =>
+      i match {
+        case 0 =>
+          assert(seq1.toSet == Set[Int](0))
+          assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+        case 1 =>
+          assert(seq1.toSet == Set[Int](1))
+          assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+        case 5000 =>
+          assert(seq1.toSet == Set[Int](5000))
+          assert(seq2.toSet == Set[Int]())
+        case 9999 =>
+          assert(seq1.toSet == Set[Int](9999))
+          assert(seq2.toSet == Set[Int]())
+        case _ =>
+      }
+    }
+
+    // larger cogroup - should spill ~7 times
+    val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val resultD = rddD1.cogroup(rddD2).collect()
+    assert(resultD.length == 5000)
+    resultD.foreach { case(i, (seq1, seq2)) =>
+      val expected = Set(i * 2, i * 2 + 1)
+      if (seq1.toSet != expected) {
+        fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+      }
+      if (seq2.toSet != expected) {
+        fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+      }
+    }
+  }
+
+  test("spilling in local cluster with many reduce tasks") {
+    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+
+    // reduceByKey - should spill ~4 times per executor
+    val rddA = sc.parallelize(0 until 100000).map(i => (i/2, i))
+    val resultA = rddA.reduceByKey(math.max _, 100).collect()
+    assert(resultA.length == 50000)
+    resultA.foreach { case(k, v) =>
+      if (v != k * 2 + 1) {
+        fail(s"Value for ${k} was wrong: expected ${k * 2 + 1}, got ${v}")
+      }
+    }
+
+    // groupByKey - should spill ~8 times per executor
+    val rddB = sc.parallelize(0 until 100000).map(i => (i/4, i))
+    val resultB = rddB.groupByKey(100).collect()
+    assert(resultB.length == 25000)
+    resultB.foreach { case(i, seq) =>
+      val expected = Set(i * 4, i * 4 + 1, i * 4 + 2, i * 4 + 3)
+      if (seq.toSet != expected) {
+        fail(s"Value for ${i} was wrong: expected ${expected}, got ${seq.toSet}")
+      }
+    }
+
+    // cogroup - should spill ~4 times per executor
+    val rddC1 = sc.parallelize(0 until 10000).map(i => (i, i))
+    val rddC2 = sc.parallelize(0 until 10000).map(i => (i%1000, i))
+    val resultC = rddC1.cogroup(rddC2, 100).collect()
+    assert(resultC.length == 10000)
+    resultC.foreach { case(i, (seq1, seq2)) =>
+      i match {
+        case 0 =>
+          assert(seq1.toSet == Set[Int](0))
+          assert(seq2.toSet == Set[Int](0, 1000, 2000, 3000, 4000, 5000, 6000, 7000, 8000, 9000))
+        case 1 =>
+          assert(seq1.toSet == Set[Int](1))
+          assert(seq2.toSet == Set[Int](1, 1001, 2001, 3001, 4001, 5001, 6001, 7001, 8001, 9001))
+        case 5000 =>
+          assert(seq1.toSet == Set[Int](5000))
+          assert(seq2.toSet == Set[Int]())
+        case 9999 =>
+          assert(seq1.toSet == Set[Int](9999))
+          assert(seq2.toSet == Set[Int]())
+        case _ =>
+      }
+    }
+
+    // larger cogroup - should spill ~4 times per executor
+    val rddD1 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val rddD2 = sc.parallelize(0 until 10000).map(i => (i/2, i))
+    val resultD = rddD1.cogroup(rddD2).collect()
+    assert(resultD.length == 5000)
+    resultD.foreach { case(i, (seq1, seq2)) =>
+      val expected = Set(i * 2, i * 2 + 1)
+      if (seq1.toSet != expected) {
+        fail(s"Value 1 for ${i} was wrong: expected ${expected}, got ${seq1.toSet}")
+      }
+      if (seq2.toSet != expected) {
+        fail(s"Value 2 for ${i} was wrong: expected ${expected}, got ${seq2.toSet}")
+      }
+    }
+  }
+
+  test("cleanup of intermediate files in sorter") {
+    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    assert(diskBlockManager.getAllFiles().length > 0)
+    sorter.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+
+    val sorter2 = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+    sorter2.write((0 until 100000).iterator.map(i => (i, i)))
+    assert(diskBlockManager.getAllFiles().length > 0)
+    assert(sorter2.iterator.toSet === (0 until 100000).map(i => (i, i)).toSet)
+    sorter2.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+  }
+
+  test("cleanup of intermediate files in sorter if there are errors") {
+    val conf = new SparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+    intercept[SparkException] {
+      sorter.write((0 until 100000).iterator.map(i => {
+        if (i == 99990) {
+          throw new SparkException("Intentional failure")
+        }
+        (i, i)
+      }))
+    }
+    assert(diskBlockManager.getAllFiles().length > 0)
+    sorter.stop()
+    assert(diskBlockManager.getAllBlocks().length === 0)
+  }
+
+  test("cleanup of intermediate files in shuffle") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+    val data = sc.parallelize(0 until 100000, 2).map(i => (i, i))
+    assert(data.reduceByKey(_ + _).count() === 100000)
+
+    // After the shuffle, there should be only 4 files on disk: our two map output files and
+    // their index files. All other intermediate files should've been deleted.
+    assert(diskBlockManager.getAllFiles().length === 4)
+  }
+
+  test("cleanup of intermediate files in shuffle with errors") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+    val diskBlockManager = SparkEnv.get.blockManager.diskBlockManager
+
+    val data = sc.parallelize(0 until 100000, 2).map(i => {
+      if (i == 99990) {
+        throw new Exception("Intentional failure")
+      }
+      (i, i)
+    })
+    intercept[SparkException] {
+      data.reduceByKey(_ + _).count()
+    }
+
+    // After the shuffle, there should be only 2 files on disk: map 1's output and its index.
+    // All other files (failed map 2's output and intermediate merge files) should've been deleted.
+    assert(diskBlockManager.getAllFiles().length === 2)
+  }
+
+  test("no partial aggregation or sorting") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val sorter = new ExternalSorter[Int, Int, Int](None, Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 4, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 100000).map(i => (i / 4, i)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation without spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation with spill, no ordering") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), None, None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("partial aggregation with spill, with ordering") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val agg = new Aggregator[Int, Int, Int](i => i, (i, j) => i + j, (i, j) => i + j)
+    val ord = implicitly[Ordering[Int]]
+    val sorter = new ExternalSorter(Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+    sorter.write((0 until 100000).iterator.map(i => (i / 2, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSet)}.toSet
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 50000).map(i => (i, i * 4 + 1)).filter(_._1 % 3 == p).toSet)
+    }).toSet
+    assert(results === expected)
+  }
+
+  test("sorting without aggregation, no spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val ord = implicitly[Ordering[Int]]
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    sorter.write((0 until 100).iterator.map(i => (i, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 100).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+    }).toSeq
+    assert(results === expected)
+  }
+
+  test("sorting without aggregation, with spill") {
+    val conf = new SparkConf(false)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
+    sc = new SparkContext("local", "test", conf)
+
+    val ord = implicitly[Ordering[Int]]
+    val sorter = new ExternalSorter[Int, Int, Int](
+      None, Some(new HashPartitioner(3)), Some(ord), None)
+    sorter.write((0 until 100000).iterator.map(i => (i, i)))
+    val results = sorter.partitionedIterator.map{case (p, vs) => (p, vs.toSeq)}.toSeq
+    val expected = (0 until 3).map(p => {
+      (p, (0 until 100000).map(i => (i, i)).filter(_._1 % 3 == p).toSeq)
+    }).toSeq
+    assert(results === expected)
+  }
+
+  test("spilling with hash collisions") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: String) = ArrayBuffer[String](i)
+    def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+    def mergeCombiners(buffer1: ArrayBuffer[String], buffer2: ArrayBuffer[String]) =
+      buffer1 ++= buffer2
+
+    val agg = new Aggregator[String, String, ArrayBuffer[String]](
+      createCombiner _, mergeValue _, mergeCombiners _)
+
+    val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+      Some(agg), None, None, None)
+
+    val collisionPairs = Seq(
+      ("Aa", "BB"),                   // 2112
+      ("to", "v1"),                   // 3707
+      ("variants", "gelato"),         // -1249574770
+      ("Teheran", "Siblings"),        // 231609873
+      ("misused", "horsemints"),      // 1069518484
+      ("isohel", "epistolaries"),     // -1179291542
+      ("righto", "buzzards"),         // -931102253
+      ("hierarch", "crinolines"),     // -1732884796
+      ("inwork", "hypercatalexes"),   // -1183663690
+      ("wainages", "presentencing"),  // 240183619
+      ("trichothecenes", "locular"),  // 339006536
+      ("pomatoes", "eructation")      // 568647356
+    )
+
+    collisionPairs.foreach { case (w1, w2) =>
+      // String.hashCode is documented to use a specific algorithm, but check just in case
+      assert(w1.hashCode === w2.hashCode)
+    }
+
+    val toInsert = (1 to 100000).iterator.map(_.toString).map(s => (s, s)) ++
+      collisionPairs.iterator ++ collisionPairs.iterator.map(_.swap)
+
+    sorter.write(toInsert)
+
+    // A map of collision pairs in both directions
+    val collisionPairsMap = (collisionPairs ++ collisionPairs.map(_.swap)).toMap
+
+    // Avoid map.size or map.iterator.length because this destructively sorts the underlying map
+    var count = 0
+
+    val it = sorter.iterator
+    while (it.hasNext) {
+      val kv = it.next()
+      val expectedValue = ArrayBuffer[String](collisionPairsMap.getOrElse(kv._1, kv._1))
+      assert(kv._2.equals(expectedValue))
+      count += 1
+    }
+    assert(count === 100000 + collisionPairs.size * 2)
+  }
+
+  test("spilling with many hash collisions") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.0001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
+    val sorter = new ExternalSorter[FixedHashObject, Int, Int](Some(agg), None, None, None)
+
+    // Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
+    // problems if the map fails to group together the objects with the same code (SPARK-2043).
+    val toInsert = for (i <- 1 to 10; j <- 1 to 10000) yield (FixedHashObject(j, j % 2), 1)
+    sorter.write(toInsert.iterator)
+
+    val it = sorter.iterator
+    var count = 0
+    while (it.hasNext) {
+      val kv = it.next()
+      assert(kv._2 === 10)
+      count += 1
+    }
+    assert(count === 10000)
+  }
+
+  test("spilling with hash collisions using the Int.MaxValue key") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: Int) = ArrayBuffer[Int](i)
+    def mergeValue(buffer: ArrayBuffer[Int], i: Int) = buffer += i
+    def mergeCombiners(buf1: ArrayBuffer[Int], buf2: ArrayBuffer[Int]) = buf1 ++= buf2
+
+    val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
+    val sorter = new ExternalSorter[Int, Int, ArrayBuffer[Int]](Some(agg), None, None, None)
+
+    sorter.write((1 to 100000).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
+
+    val it = sorter.iterator
+    while (it.hasNext) {
+      // Should not throw NoSuchElementException
+      it.next()
+    }
+  }
+
+  test("spilling with null keys and values") {
+    val conf = new SparkConf(true)
+    conf.set("spark.shuffle.memoryFraction", "0.001")
+    sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
+
+    def createCombiner(i: String) = ArrayBuffer[String](i)
+    def mergeValue(buffer: ArrayBuffer[String], i: String) = buffer += i
+    def mergeCombiners(buf1: ArrayBuffer[String], buf2: ArrayBuffer[String]) = buf1 ++= buf2
+
+    val agg = new Aggregator[String, String, ArrayBuffer[String]](
+      createCombiner, mergeValue, mergeCombiners)
+
+    val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
+      Some(agg), None, None, None)
+
+    sorter.write((1 to 100000).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
+      (null.asInstanceOf[String], "1"),
+      ("1", null.asInstanceOf[String]),
+      (null.asInstanceOf[String], null.asInstanceOf[String])
+    ))
+
+    val it = sorter.iterator
+    while (it.hasNext) {
+      // Should not throw NullPointerException
+      it.next()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
new file mode 100644
index 0000000..c787b5f
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/FixedHashObject.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+/**
+ * A dummy class whose hash code is simply `h`, ignoring `v`, to easily test hash collisions
+ */
+case class FixedHashObject(v: Int, h: Int) extends Serializable {
+  override def hashCode(): Int = h
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
index 5318b8d..714f3b8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.{ShuffledRDD, RDD}
 private[graphx]
 class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
   def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
-    val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner)
+    val rdd = new ShuffledRDD[VertexId, VD, VD](self, partitioner)
 
     // Set a custom serializer if the data is of int or double type.
     if (classTag[VD] == ClassTag.Int) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index a565d3b..b274859 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -33,7 +33,7 @@ private[graphx]
 class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
   /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
   def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
-    new ShuffledRDD[VertexId, Int, Int, RoutingTableMessage](
+    new ShuffledRDD[VertexId, Int, Int](
       self, partitioner).setSerializer(new RoutingTableMessageSerializer)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/project/SparkBuild.scala
----------------------------------------------------------------------
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 672343f..a8bbd55 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -295,6 +295,7 @@ object Unidoc {
         .map(_.filterNot(_.getCanonicalPath.contains("akka")))
         .map(_.filterNot(_.getCanonicalPath.contains("deploy")))
         .map(_.filterNot(_.getCanonicalPath.contains("network")))
+        .map(_.filterNot(_.getCanonicalPath.contains("shuffle")))
         .map(_.filterNot(_.getCanonicalPath.contains("executor")))
         .map(_.filterNot(_.getCanonicalPath.contains("python")))
         .map(_.filterNot(_.getCanonicalPath.contains("collection")))

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 392a7f3..30712f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -49,7 +49,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(r => mutablePair.update(hashExpressions(r), r))
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part)
+        val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
@@ -62,7 +62,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(row => mutablePair.update(row, null))
         }
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part)
+        val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
 
         shuffled.map(_._1)
@@ -73,7 +73,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(r => mutablePair.update(null, r))
         }
         val partitioner = new HashPartitioner(1)
-        val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner)
+        val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 174eda8..0027f3c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -148,7 +148,7 @@ case class Limit(limit: Int, child: SparkPlan)
       iter.take(limit).map(row => mutablePair.update(false, row))
     }
     val part = new HashPartitioner(1)
-    val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part)
+    val shuffled = new ShuffledRDD[Boolean, Row, Row](rdd, part)
     shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
     shuffled.mapPartitions(_.take(limit).map(_._2))
   }


[2/2] git commit: SPARK-2045 Sort-based shuffle

Posted by rx...@apache.org.
SPARK-2045 Sort-based shuffle

This adds a new ShuffleManager based on sorting, as described in https://issues.apache.org/jira/browse/SPARK-2045. The bulk of the code is in an ExternalSorter class that is similar to ExternalAppendOnlyMap, but sorts key-value pairs by partition ID and can be used to create a single sorted file with a map task's output. (Longer-term, I think this could take over the remaining functionality of ExternalAppendOnlyMap and replace it, so that we don't have code duplication.)

The main TODOs from the original proposal (all now completed) were:
- [x] enabling ExternalSorter to merge across spilled files
  - [x] with an Ordering
  - [x] without an Ordering, using the keys' hash codes
- [x] adding more tests (e.g. a version of our shuffle suite that runs on this)
- [x] rebasing on top of the size-tracking refactoring in #1165 when that is merged
- [x] disabling spilling if spark.shuffle.spill is set to false

Even so, this seems to work well (running successfully in cases where the hash shuffle would OOM, such as 1000 reduce tasks on executors with only 1 GB of memory), and it seems to be comparable in speed to, or faster than, the hash-based shuffle (it creates far fewer files for the OS to keep track of). So I'm posting it to get some early feedback.

After these TODOs are done, I'd also like to enable ExternalSorter to sort data by key within each partition, which will allow us to use it to implement external spilling in the reduce tasks of `sortByKey`.

Author: Matei Zaharia <ma...@databricks.com>

Closes #1499 from mateiz/sort-based-shuffle and squashes the following commits:

bd841f9 [Matei Zaharia] Various review comments
d1c137fd [Matei Zaharia] Various review comments
a611159 [Matei Zaharia] Compile fixes due to rebase
62c56c8 [Matei Zaharia] Fix ShuffledRDD sometimes not returning Tuple2s.
f617432 [Matei Zaharia] Fix a failing test (seems to be due to change in SizeTracker logic)
9464d5f [Matei Zaharia] Simplify code and fix conflicts after latest rebase
0174149 [Matei Zaharia] Add cleanup behavior and cleanup tests for sort-based shuffle
eb4ee0d [Matei Zaharia] Remove customizable element type in ShuffledRDD
fa2e8db [Matei Zaharia] Allow nextBatchStream to be called after we're done looking at all streams
a34b352 [Matei Zaharia] Fix tracking of indices within a partition in SpillReader, and add test
03e1006 [Matei Zaharia] Add a SortShuffleSuite that runs ShuffleSuite with sort-based shuffle
3c7ff1f [Matei Zaharia] Obey the spark.shuffle.spill setting in ExternalSorter
ad65fbd [Matei Zaharia] Rebase on top of Aaron's Sorter change, and use Sorter in our buffer
44d2a93 [Matei Zaharia] Use estimateSize instead of atGrowThreshold to test collection sizes
5686f71 [Matei Zaharia] Optimize merging phase for in-memory only data:
5461cbb [Matei Zaharia] Review comments and more tests (e.g. tests with 1 element per partition)
e9ad356 [Matei Zaharia] Update ContextCleanerSuite to make sure shuffle cleanup tests use hash shuffle (since they were written for it)
c72362a [Matei Zaharia] Added bug fix and test for when iterators are empty
de1fb40 [Matei Zaharia] Make trait SizeTrackingCollection private[spark]
4988d16 [Matei Zaharia] tweak
c1b7572 [Matei Zaharia] Small optimization
ba7db7f [Matei Zaharia] Handle null keys in hash-based comparator, and add tests for collisions
ef4e397 [Matei Zaharia] Support for partial aggregation even without an Ordering
4b7a5ce [Matei Zaharia] More tests, and ability to sort data if a total ordering is given
e1f84be [Matei Zaharia] Fix disk block manager test
5a40a1c [Matei Zaharia] More tests
614f1b4 [Matei Zaharia] Add spill metrics to map tasks
cc52caf [Matei Zaharia] Add more error handling and tests for error cases
bbf359d [Matei Zaharia] More work
3a56341 [Matei Zaharia] More partial work towards sort-based shuffle
7a0895d [Matei Zaharia] Some more partial work towards sort-based shuffle
b615476 [Matei Zaharia] Scaffolding for sort-based shuffle


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e9662844
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e9662844
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e9662844

Branch: refs/heads/master
Commit: e966284409f9355e1169960e73a2215617c8cb22
Parents: da50176
Author: Matei Zaharia <ma...@databricks.com>
Authored: Wed Jul 30 18:07:59 2014 -0700
Committer: Reynold Xin <rx...@apache.org>
Committed: Wed Jul 30 18:07:59 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/Aggregator.scala     |  24 +-
 .../scala/org/apache/spark/SparkContext.scala   |   8 +-
 .../org/apache/spark/api/java/JavaPairRDD.scala |   2 +-
 .../org/apache/spark/rdd/CoGroupedRDD.scala     |   7 +-
 .../apache/spark/rdd/OrderedRDDFunctions.scala  |  14 +-
 .../org/apache/spark/rdd/PairRDDFunctions.scala |   4 +-
 .../main/scala/org/apache/spark/rdd/RDD.scala   |   8 +-
 .../org/apache/spark/rdd/ShuffledRDD.scala      |  17 +-
 .../spark/shuffle/hash/HashShuffleManager.scala |   2 +-
 .../spark/shuffle/hash/HashShuffleReader.scala  |   5 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   6 +-
 .../spark/shuffle/sort/SortShuffleManager.scala |  80 +++
 .../spark/shuffle/sort/SortShuffleWriter.scala  | 165 +++++
 .../org/apache/spark/storage/BlockId.scala      |  11 +-
 .../apache/spark/storage/DiskBlockManager.scala |  38 +-
 .../spark/storage/ShuffleBlockManager.scala     |  29 +-
 .../util/collection/ExternalAppendOnlyMap.scala |  36 +-
 .../spark/util/collection/ExternalSorter.scala  | 662 +++++++++++++++++++
 .../collection/SizeTrackingAppendOnlyMap.scala  |   5 +-
 .../collection/SizeTrackingPairBuffer.scala     |  86 +++
 .../collection/SizeTrackingPairCollection.scala |  34 +
 .../org/apache/spark/CheckpointSuite.scala      |   2 +-
 .../org/apache/spark/ContextCleanerSuite.scala  | 186 ++++--
 .../org/apache/spark/ShuffleNettySuite.scala    |   2 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   |  26 +-
 .../org/apache/spark/SortShuffleSuite.scala     |  34 +
 .../scala/org/apache/spark/rdd/RDDSuite.scala   |   6 +-
 .../collection/ExternalAppendOnlyMapSuite.scala |  25 +-
 .../util/collection/ExternalSorterSuite.scala   | 566 ++++++++++++++++
 .../spark/util/collection/FixedHashObject.scala |  25 +
 .../spark/graphx/impl/MessageToPartition.scala  |   2 +-
 .../graphx/impl/RoutingTablePartition.scala     |   2 +-
 project/SparkBuild.scala                        |   1 +
 .../apache/spark/sql/execution/Exchange.scala   |   6 +-
 .../spark/sql/execution/basicOperators.scala    |   2 +-
 35 files changed, 1969 insertions(+), 159 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/Aggregator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index ff0ca11..79c9c45 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -56,18 +56,23 @@ case class Aggregator[K, V, C] (
     } else {
       val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
       combiners.insertAll(iter)
-      // TODO: Make this non optional in a future release
-      Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
-      Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+      // Update task metrics if context is not null
+      // TODO: Make context non optional in a future release
+      Option(context).foreach { c =>
+        c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+        c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+      }
       combiners.iterator
     }
   }
 
   @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0")
-  def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] =
+  def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] =
     combineCombinersByKey(iter, null)
 
-  def combineCombinersByKey(iter: Iterator[(K, C)], context: TaskContext) : Iterator[(K, C)] = {
+  def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
+      : Iterator[(K, C)] =
+  {
     if (!externalSorting) {
       val combiners = new AppendOnlyMap[K,C]
       var kc: Product2[K, C] = null
@@ -85,9 +90,12 @@ case class Aggregator[K, V, C] (
         val pair = iter.next()
         combiners.insert(pair._1, pair._2)
       }
-      // TODO: Make this non optional in a future release
-      Option(context).foreach(c => c.taskMetrics.memoryBytesSpilled = combiners.memoryBytesSpilled)
-      Option(context).foreach(c => c.taskMetrics.diskBytesSpilled = combiners.diskBytesSpilled)
+      // Update task metrics if context is not null
+      // TODO: Make context non-optional in a future release
+      Option(context).foreach { c =>
+        c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
+        c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+      }
       combiners.iterator
     }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index fb4c867..b25f081 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -289,7 +289,7 @@ class SparkContext(config: SparkConf) extends Logging {
     value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
     executorEnvs(envKey) = value
   }
-  Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v => 
+  Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
     executorEnvs("SPARK_PREPEND_CLASSES") = v
   }
   // The Mesos scheduler backend relies on this environment variable to set executor memory.
@@ -1203,10 +1203,10 @@ class SparkContext(config: SparkConf) extends Logging {
   /**
    * Clean a closure to make it ready to serialized and send to tasks
    * (removes unreferenced variables in $outer's, updates REPL variables)
-   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively 
-   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt> 
+   * If <tt>checkSerializable</tt> is set, <tt>clean</tt> will also proactively
+   * check to see if <tt>f</tt> is serializable and throw a <tt>SparkException</tt>
    * if not.
-   * 
+   *
    * @param f the closure to clean
    * @param checkSerializable whether or not to immediately check <tt>f</tt> for serializability
    * @throws <tt>SparkException<tt> if <tt>checkSerializable</tt> is set but <tt>f</tt> is not

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 31bf8dc..47708cb 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -122,7 +122,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    */
   def sample(withReplacement: Boolean, fraction: Double): JavaPairRDD[K, V] =
     sample(withReplacement, fraction, Utils.random.nextLong)
-    
+
   /**
    * Return a sampled subset of this RDD.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 6388ef8..fabb882 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.rdd
 
+import scala.language.existentials
+
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
-import scala.language.existentials
 
 import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -157,8 +158,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       for ((it, depNum) <- rddIterators) {
         map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
       }
-      context.taskMetrics.memoryBytesSpilled = map.memoryBytesSpilled
-      context.taskMetrics.diskBytesSpilled = map.diskBytesSpilled
+      context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
+      context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
       new InterruptibleIterator(context,
         map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index d85f962..e98bad2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Logging, RangePartitioner}
+import org.apache.spark.annotation.DeveloperApi
 
 /**
  * Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -43,10 +44,10 @@ import org.apache.spark.{Logging, RangePartitioner}
  */
 class OrderedRDDFunctions[K : Ordering : ClassTag,
                           V: ClassTag,
-                          P <: Product2[K, V] : ClassTag](
+                          P <: Product2[K, V] : ClassTag] @DeveloperApi() (
     self: RDD[P])
-  extends Logging with Serializable {
-
+  extends Logging with Serializable
+{
   private val ordering = implicitly[Ordering[K]]
 
   /**
@@ -55,9 +56,12 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
    * order of the keys).
    */
-  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
+  // TODO: this currently doesn't work on P other than Tuple2!
+  def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size)
+      : RDD[(K, V)] =
+  {
     val part = new RangePartitioner(numPartitions, self, ascending)
-    new ShuffledRDD[K, V, V, P](self, part)
+    new ShuffledRDD[K, V, V](self, part)
       .setKeyOrdering(if (ascending) ordering else ordering.reverse)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 1af4e5f..93af50c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -90,7 +90,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
     } else {
-      new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
+      new ShuffledRDD[K, V, C](self, partitioner)
         .setSerializer(serializer)
         .setAggregator(aggregator)
         .setMapSideCombine(mapSideCombine)
@@ -425,7 +425,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     if (self.partitioner == Some(partitioner)) {
       self
     } else {
-      new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
+      new ShuffledRDD[K, V, V](self, partitioner)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 726b3f2..74ac970 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag](
       val distributePartition = (index: Int, items: Iterator[T]) => {
         var position = (new Random(index)).nextInt(numPartitions)
         items.map { t =>
-          // Note that the hash code of the key will just be the key itself. The HashPartitioner 
+          // Note that the hash code of the key will just be the key itself. The HashPartitioner
           // will mod it with the number of total partitions.
           position = position + 1
           (position, t)
@@ -341,7 +341,7 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
+        new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition),
         new HashPartitioner(numPartitions)),
         numPartitions).values
     } else {
@@ -352,8 +352,8 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a sampled subset of this RDD.
    */
-  def sample(withReplacement: Boolean, 
-      fraction: Double, 
+  def sample(withReplacement: Boolean,
+      fraction: Double,
       seed: Long = Utils.random.nextLong): RDD[T] = {
     require(fraction >= 0.0, "Negative fraction value: " + fraction)
     if (withReplacement) {

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index bf02f68..d9fe684 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -37,11 +37,12 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  * @tparam V the value class.
  * @tparam C the combiner class.
  */
+// TODO: Make this return RDD[Product2[K, C]] or have some way to configure mutable pairs
 @DeveloperApi
-class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
+class ShuffledRDD[K, V, C](
     @transient var prev: RDD[_ <: Product2[K, V]],
     part: Partitioner)
-  extends RDD[P](prev.context, Nil) {
+  extends RDD[(K, C)](prev.context, Nil) {
 
   private var serializer: Option[Serializer] = None
 
@@ -52,25 +53,25 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
   private var mapSideCombine: Boolean = false
 
   /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
-  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
+  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
     this.serializer = Option(serializer)
     this
   }
 
   /** Set key ordering for RDD's shuffle. */
-  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
+  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C] = {
     this.keyOrdering = Option(keyOrdering)
     this
   }
 
   /** Set aggregator for RDD's shuffle. */
-  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
+  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C] = {
     this.aggregator = Option(aggregator)
     this
   }
 
   /** Set mapSideCombine flag for RDD's shuffle. */
-  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
+  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C] = {
     this.mapSideCombine = mapSideCombine
     this
   }
@@ -85,11 +86,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
     Array.tabulate[Partition](part.numPartitions)(i => new ShuffledRDDPartition(i))
   }
 
-  override def compute(split: Partition, context: TaskContext): Iterator[P] = {
+  override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
     val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
       .read()
-      .asInstanceOf[Iterator[P]]
+      .asInstanceOf[Iterator[(K, C)]]
   }
 
   override def clearDependencies() {

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
index 5b0940e..df98d18 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -24,7 +24,7 @@ import org.apache.spark.shuffle._
  * A ShuffleManager using hashing, that creates one output file per reduce partition on each
  * mapper (possibly reusing these across waves of tasks).
  */
-class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
   /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
   override def registerShuffle[K, V, C](
       shuffleId: Int,

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index c805949..e32ad9c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 
-class HashShuffleReader[K, C](
+private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
@@ -47,7 +47,8 @@ class HashShuffleReader[K, C](
     } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
       throw new IllegalStateException("Aggregator is empty for map-side combine")
     } else {
-      iter
+      // Convert the Product2s to pairs since this is what downstream RDDs currently expect
+      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
     }
 
     // Sort the output if there is a sort ordering defined.

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 9b78228..1923f7c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -24,7 +24,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 
-class HashShuffleWriter[K, V](
+private[spark] class HashShuffleWriter[K, V](
     handle: BaseShuffleHandle[K, V, _],
     mapId: Int,
     context: TaskContext)
@@ -33,6 +33,10 @@ class HashShuffleWriter[K, V](
   private val dep = handle.dependency
   private val numOutputSplits = dep.partitioner.numPartitions
   private val metrics = context.taskMetrics
+
+  // Are we in the process of stopping? Because map tasks can call stop() with success = true
+  // and then call stop() with success = false if they get an exception, we want to make sure
+  // we don't try deleting files, etc twice.
   private var stopping = false
 
   private val blockManager = SparkEnv.get.blockManager

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
new file mode 100644
index 0000000..6dcca47
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{DataInputStream, FileInputStream}
+
+import org.apache.spark.shuffle._
+import org.apache.spark.{TaskContext, ShuffleDependency}
+import org.apache.spark.shuffle.hash.HashShuffleReader
+import org.apache.spark.storage.{DiskBlockManager, FileSegment, ShuffleBlockId,
+  ShuffleIndexBlockId}
+
+/**
+ * A ShuffleManager in which each map task writes all of its output into a single data file,
+ * one reduce partition after another, together with an index file of partition offsets (see
+ * SortShuffleWriter). Reduce tasks reuse the hash-based shuffle's HashShuffleReader.
+ */
+private[spark] class SortShuffleManager extends ShuffleManager {
+  /**
+   * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+   */
+  override def registerShuffle[K, V, C](
+      shuffleId: Int,
+      numMaps: Int,
+      dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+    new BaseShuffleHandle(shuffleId, numMaps, dependency)
+  }
+
+  /**
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+   * Called on executors by reduce tasks.
+   */
+  override def getReader[K, C](
+      handle: ShuffleHandle,
+      startPartition: Int,
+      endPartition: Int,
+      context: TaskContext): ShuffleReader[K, C] = {
+    // We currently use the same block store shuffle fetcher as the hash-based shuffle.
+    new HashShuffleReader(
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+  }
+
+  /** Get a writer for a given partition. Called on executors by map tasks. */
+  override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+      : ShuffleWriter[K, V] = {
+    new SortShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+  }
+
+  /** Remove a shuffle's metadata from the ShuffleManager. */
+  override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+  /** Shut down this ShuffleManager. */
+  override def stop(): Unit = {}
+
+  /**
+   * Get the location of a block in a map output file. Uses the index file we create for it.
+   *
+   * All reduce partitions of one map task share a single data file named after reduce ID 0.
+   * Its index file is a sequence of longs in which entry i is the byte offset where reduce
+   * partition i starts and entry i + 1 is the offset where it ends.
+   */
+  def getBlockLocation(blockId: ShuffleBlockId, diskManager: DiskBlockManager): FileSegment = {
+    // The block is actually going to be a range of a single map output file for this map, so
+    // figure out the ID of the consolidated file, then the offset within that from our index
+    val consolidatedId = blockId.copy(reduceId = 0)
+    val indexFile = diskManager.getFile(ShuffleIndexBlockId(blockId.shuffleId, blockId.mapId, 0))
+    val fileIn = new FileInputStream(indexFile)
+    try {
+      // Seek directly to this partition's entry. InputStream.skip() may legally skip fewer
+      // bytes than requested, which would silently yield wrong offsets; multiplying by a Long
+      // also avoids Int overflow for very large reduce IDs.
+      fileIn.getChannel.position(blockId.reduceId * 8L)
+      val in = new DataInputStream(fileIn)
+      val offset = in.readLong()
+      val nextOffset = in.readLong()
+      new FileSegment(diskManager.getFile(consolidatedId), offset, nextOffset - offset)
+    } finally {
+      fileIn.close()
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
new file mode 100644
index 0000000..42fcd07
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{BufferedOutputStream, File, FileOutputStream, DataOutputStream}
+
+import org.apache.spark.{MapOutputTracker, SparkEnv, Logging, TaskContext}
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.storage.{ShuffleBlockId, ShuffleIndexBlockId}
+import org.apache.spark.util.collection.ExternalSorter
+
+/**
+ * A ShuffleWriter for sort-based shuffle. An ExternalSorter groups this map task's records
+ * by reduce partition; the partitions are then written back to back into one data file, and
+ * an index file of partition offsets lets SortShuffleManager.getBlockLocation serve each
+ * reduce partition as a range of that file.
+ */
+private[spark] class SortShuffleWriter[K, V, C](
+    handle: BaseShuffleHandle[K, V, C],
+    mapId: Int,
+    context: TaskContext)
+  extends ShuffleWriter[K, V] with Logging {
+
+  private val dep = handle.dependency
+  private val numPartitions = dep.partitioner.numPartitions
+
+  private val blockManager = SparkEnv.get.blockManager
+  private val ser = Serializer.getSerializer(dep.serializer.orNull)
+
+  private val conf = SparkEnv.get.conf
+  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+  private var sorter: ExternalSorter[K, V, _] = null
+
+  // Data and index files created by write(); stop(success = false) deletes them
+  private var outputFile: File = null
+  private var indexFile: File = null
+
+  // Are we in the process of stopping? Because map tasks can call stop() with success = true
+  // and then call stop() with success = false if they get an exception, we want to make sure
+  // we don't try deleting files, etc. twice.
+  private var stopping = false
+
+  private var mapStatus: MapStatus = null
+
+  /** Write a bunch of records to this task's output */
+  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    // Get an iterator with the elements for each partition ID
+    val partitions: Iterator[(Int, Iterator[Product2[K, _]])] = {
+      if (dep.mapSideCombine) {
+        if (!dep.aggregator.isDefined) {
+          throw new IllegalStateException("Aggregator is empty for map-side combine")
+        }
+        sorter = new ExternalSorter[K, V, C](
+          dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
+        sorter.write(records)
+        sorter.partitionedIterator
+      } else {
+        // In this case we pass neither an aggregator nor an ordering to the sorter, because we
+        // don't care whether the keys get sorted in each partition; that will be done on the
+        // reduce side if the operation being run is sortByKey.
+        sorter = new ExternalSorter[K, V, V](
+          None, Some(dep.partitioner), None, dep.serializer)
+        sorter.write(records)
+        sorter.partitionedIterator
+      }
+    }
+
+    // Create a single shuffle file with reduce ID 0 that we'll write all results to. We'll later
+    // serve different ranges of this file using an index file that we create at the end.
+    val blockId = ShuffleBlockId(dep.shuffleId, mapId, 0)
+    outputFile = blockManager.diskBlockManager.getFile(blockId)
+
+    // Track location of each range in the output file
+    val offsets = new Array[Long](numPartitions + 1)
+    val lengths = new Array[Long](numPartitions)
+
+    // Statistics
+    var totalBytes = 0L
+    var totalTime = 0L
+
+    for ((id, elements) <- partitions) {
+      if (elements.hasNext) {
+        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+        try {
+          for (elem <- elements) {
+            writer.write(elem)
+          }
+          writer.commit()
+        } finally {
+          // Release the file handle even if serializing or writing a record throws
+          writer.close()
+        }
+        val segment = writer.fileSegment()
+        offsets(id + 1) = segment.offset + segment.length
+        lengths(id) = segment.length
+        totalTime += writer.timeWriting()
+        totalBytes += segment.length
+      } else {
+        // The partition is empty; don't create a new writer to avoid writing headers, etc.
+        offsets(id + 1) = offsets(id)
+      }
+    }
+
+    val shuffleMetrics = new ShuffleWriteMetrics
+    shuffleMetrics.shuffleBytesWritten = totalBytes
+    shuffleMetrics.shuffleWriteTime = totalTime
+    context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
+    context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
+    context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+
+    // Write an index file with the offsets of each block, plus a final offset at the end for the
+    // end of the output file. This will be used by SortShuffleManager.getBlockLocation to figure
+    // out where each block begins and ends.
+    indexFile = blockManager.diskBlockManager.getFile(
+      ShuffleIndexBlockId(dep.shuffleId, mapId, 0))
+    val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+    try {
+      offsets.foreach(offset => out.writeLong(offset))
+    } finally {
+      out.close()
+    }
+
+    // Register our map output with the ShuffleBlockManager, which handles cleaning it over time
+    blockManager.shuffleBlockManager.addCompletedMap(dep.shuffleId, mapId, numPartitions)
+
+    mapStatus = new MapStatus(blockManager.blockManagerId,
+      lengths.map(MapOutputTracker.compressSize))
+  }
+
+  /** Close this writer, passing along whether the map completed */
+  override def stop(success: Boolean): Option[MapStatus] = {
+    try {
+      if (stopping) {
+        None
+      } else {
+        stopping = true
+        if (success) {
+          Option(mapStatus)
+        } else {
+          // The map task failed, so delete the data and index files if we created them
+          if (outputFile != null) {
+            outputFile.delete()
+          }
+          if (indexFile != null) {
+            indexFile.delete()
+          }
+          None
+        }
+      }
+    } finally {
+      // Clean up our sorter, which may have its own intermediate files
+      if (sorter != null) {
+        sorter.stop()
+        sorter = null
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/storage/BlockId.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 42ec181..c1756ac 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -54,12 +54,16 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
 }
 
 @DeveloperApi
-case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int)
-  extends BlockId {
+case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
 @DeveloperApi
+case class ShuffleIndexBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
+  def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
+}
+
+@DeveloperApi
 case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
   def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
 }
@@ -88,6 +92,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)\\.index".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -99,6 +104,8 @@ object BlockId {
       RDDBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
+    case SHUFFLE_INDEX(shuffleId, mapId, reduceId) =>
+      ShuffleIndexBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 2e7ed75..4d66cce 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -21,10 +21,11 @@ import java.io.File
 import java.text.SimpleDateFormat
 import java.util.{Date, Random, UUID}
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkEnv, Logging}
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
 import org.apache.spark.util.Utils
+import org.apache.spark.shuffle.sort.SortShuffleManager
 
 /**
  * Creates and maintains the logical mapping between logical blocks and physical on-disk
@@ -34,11 +35,13 @@ import org.apache.spark.util.Utils
  *
  * @param rootDirs The directories to use for storing block files. Data will be hashed among these.
  */
-private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootDirs: String)
+private[spark] class DiskBlockManager(shuffleBlockManager: ShuffleBlockManager, rootDirs: String)
   extends PathResolver with Logging {
 
   private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
-  private val subDirsPerLocalDir = shuffleManager.conf.getInt("spark.diskStore.subDirectories", 64)
+
+  private val subDirsPerLocalDir =
+    shuffleBlockManager.conf.getInt("spark.diskStore.subDirectories", 64)
 
   /* Create one local directory for each path mentioned in spark.local.dir; then, inside this
    * directory, create multiple subdirectories that we will hash files into, in order to avoid
@@ -54,13 +57,19 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
   addShutdownHook()
 
   /**
-   * Returns the physical file segment in which the given BlockId is located.
-   * If the BlockId has been mapped to a specific FileSegment, that will be returned.
-   * Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
+   * Returns the physical file segment in which the given BlockId is located. If the BlockId has
+   * been mapped to a specific FileSegment by the shuffle layer, that will be returned.
+   * Otherwise, we assume the Block is mapped to the whole file identified by the BlockId.
    */
   def getBlockLocation(blockId: BlockId): FileSegment = {
-    if (blockId.isShuffle && shuffleManager.consolidateShuffleFiles) {
-      shuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
+    val env = SparkEnv.get  // NOTE: can be null in unit tests
+    if (blockId.isShuffle && env != null && env.shuffleManager.isInstanceOf[SortShuffleManager]) {
+      // For sort-based shuffle, let it figure out its blocks
+      val sortShuffleManager = env.shuffleManager.asInstanceOf[SortShuffleManager]
+      sortShuffleManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId], this)
+    } else if (blockId.isShuffle && shuffleBlockManager.consolidateShuffleFiles) {
+      // For hash-based shuffle with consolidated files, ShuffleBlockManager takes care of this
+      shuffleBlockManager.getBlockLocation(blockId.asInstanceOf[ShuffleBlockId])
     } else {
       val file = getFile(blockId.name)
       new FileSegment(file, 0, file.length())
@@ -99,13 +108,18 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD
     getBlockLocation(blockId).file.exists()
   }
 
-  /** List all the blocks currently stored on disk by the disk manager. */
-  def getAllBlocks(): Seq[BlockId] = {
+  /** List all the files currently stored on disk by the disk manager. */
+  def getAllFiles(): Seq[File] = {
     // Get all the files inside the array of array of directories
     subDirs.flatten.filter(_ != null).flatMap { dir =>
-      val files = dir.list()
+      val files = dir.listFiles()
+      // listFiles() returns null if the directory cannot be listed; treat that as no files
       if (files != null) files else Seq.empty
-    }.map(BlockId.apply)
+    }
+  }
+
+  /** List all the blocks currently stored on disk by the disk manager. */
+  def getAllBlocks(): Seq[BlockId] = {
+    // Block files are named after their block (see getFile(blockId.name)), so rebuild the ID
+    // from each file name
+    getAllFiles().map(f => BlockId(f.getName))
   }
 
   /** Produces a unique block id and File suitable for intermediate results. */

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 35910e5..7beb55c 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -28,6 +28,7 @@ import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
+import org.apache.spark.shuffle.sort.SortShuffleManager
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -58,6 +59,7 @@ private[spark] trait ShuffleWriterGroup {
  * each block stored in each file. In order to find the location of a shuffle block, we search the
  * files within a ShuffleFileGroups associated with the block's reducer.
  */
+// TODO: Factor this into a separate class for each ShuffleManager implementation
 private[spark]
 class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   def conf = blockManager.conf
@@ -67,6 +69,10 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   val consolidateShuffleFiles =
     conf.getBoolean("spark.shuffle.consolidateFiles", false)
 
+  // Are we using sort-based shuffle?
+  // NOTE(review): this is an exact match against SortShuffleManager's fully-qualified class
+  // name, so any other value of spark.shuffle.manager (e.g. a subclass) is treated as hash-based.
+  val sortBasedShuffle =
+    conf.get("spark.shuffle.manager", "") == classOf[SortShuffleManager].getName
+
   private val bufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
 
   /**
@@ -91,6 +97,20 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   private val metadataCleaner =
     new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
 
+  /**
+   * Mark map task `mapId` of shuffle `shuffleId` as completed, first creating the shuffle's
+   * state (with `numBuckets` buckets) if this shuffle has none yet. This bypasses
+   * ShuffleWriterGroup because sort-based shuffle writes its single output file by itself.
+   */
+  def addCompletedMap(shuffleId: Int, mapId: Int, numBuckets: Int): Unit = {
+    val freshState = new ShuffleState(numBuckets)
+    shuffleStates.putIfAbsent(shuffleId, freshState)
+    shuffleStates(shuffleId).completedMapTasks.add(mapId)
+  }
+
+  /**
+   * Get a ShuffleWriterGroup for the given map task, which will register it as complete
+   * when the writers are closed successfully
+   */
   def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
     new ShuffleWriterGroup {
       shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
@@ -182,7 +202,14 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
   private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = {
     shuffleStates.get(shuffleId) match {
       case Some(state) =>
-        if (consolidateShuffleFiles) {
+        if (sortBasedShuffle) {
+          // There's a single block ID for each map, plus an index file for it
+          for (mapId <- state.completedMapTasks) {
+            val blockId = new ShuffleBlockId(shuffleId, mapId, 0)
+            blockManager.diskBlockManager.getFile(blockId).delete()
+            blockManager.diskBlockManager.getFile(blockId.name + ".index").delete()
+          }
+        } else if (consolidateShuffleFiles) {
           for (fileGroup <- state.allFileGroups; file <- fileGroup.files) {
             file.delete()
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 6f263c3..b34512e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -79,12 +79,16 @@ class ExternalAppendOnlyMap[K, V, C](
     (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
   }
 
-  // Number of pairs in the in-memory map
-  private var numPairsInMemory = 0L
+  // Number of pairs inserted since last spill; note that we count them even if a value is merged
+  // with a previous key in case we're doing something like groupBy where the result grows
+  private var elementsRead = 0L
 
   // Number of in-memory pairs inserted before tracking the map's shuffle memory usage
   private val trackMemoryThreshold = 1000
 
+  // How much of the shared memory pool this collection has claimed
+  // (in bytes; raised when a claim succeeds during insertion, reset to 0 after a spill)
+  private var myMemoryThreshold = 0L
+
   /**
    * Size of object batches when reading/writing from serializers.
    *
@@ -106,7 +110,6 @@ class ExternalAppendOnlyMap[K, V, C](
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
-  private val threadId = Thread.currentThread().getId
 
   /**
    * Insert the given key and value into the map.
@@ -134,31 +137,35 @@ class ExternalAppendOnlyMap[K, V, C](
 
     while (entries.hasNext) {
       curEntry = entries.next()
-      if (numPairsInMemory > trackMemoryThreshold && currentMap.atGrowThreshold) {
-        val mapSize = currentMap.estimateSize()
+      if (elementsRead > trackMemoryThreshold && elementsRead % 32 == 0 &&
+          currentMap.estimateSize() >= myMemoryThreshold)
+      {
+        val currentSize = currentMap.estimateSize()
         var shouldSpill = false
         val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
 
         // Atomically check whether there is sufficient memory in the global pool for
         // this map to grow and, if possible, allocate the required amount
         shuffleMemoryMap.synchronized {
+          val threadId = Thread.currentThread().getId
           val previouslyOccupiedMemory = shuffleMemoryMap.get(threadId)
           val availableMemory = maxMemoryThreshold -
             (shuffleMemoryMap.values.sum - previouslyOccupiedMemory.getOrElse(0L))
 
-          // Assume map growth factor is 2x
-          shouldSpill = availableMemory < mapSize * 2
+          // Try to allocate at least 2x more memory, otherwise spill
+          shouldSpill = availableMemory < currentSize * 2
           if (!shouldSpill) {
-            shuffleMemoryMap(threadId) = mapSize * 2
+            shuffleMemoryMap(threadId) = currentSize * 2
+            myMemoryThreshold = currentSize * 2
           }
         }
         // Do not synchronize spills
         if (shouldSpill) {
-          spill(mapSize)
+          spill(currentSize)
         }
       }
       currentMap.changeValue(curEntry._1, update)
-      numPairsInMemory += 1
+      elementsRead += 1
     }
   }
 
@@ -178,9 +185,10 @@ class ExternalAppendOnlyMap[K, V, C](
   /**
    * Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
    */
-  private def spill(mapSize: Long) {
+  private def spill(mapSize: Long): Unit = {
     spillCount += 1
-    logWarning("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
+    val threadId = Thread.currentThread().getId
+    logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so far)"
       .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
     var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
@@ -227,7 +235,9 @@ class ExternalAppendOnlyMap[K, V, C](
     shuffleMemoryMap.synchronized {
       shuffleMemoryMap(Thread.currentThread().getId) = 0
     }
-    numPairsInMemory = 0
+    myMemoryThreshold = 0
+
+    elementsRead = 0
     _memoryBytesSpilled += mapSize
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
new file mode 100644
index 0000000..54c3310
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -0,0 +1,662 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io._
+import java.util.Comparator
+
+import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable
+
+import com.google.common.io.ByteStreams
+
+import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.BlockId
+
+/**
+ * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
+ * pairs of type (K, C). Uses a Partitioner to first group the keys into partitions, and then
+ * optionally sorts keys within each partition using a custom Comparator. Can output a single
+ * partitioned file with a different byte range for each partition, suitable for shuffle fetches.
+ *
+ * If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
+ *
+ * @param aggregator optional Aggregator with combine functions to use for merging data
+ * @param partitioner optional Partitioner; if given, sort by partition ID and then key
+ * @param ordering optional Ordering to sort keys within each partition; should be a total ordering
+ * @param serializer serializer to use when spilling to disk
+ *
+ * Note that if an Ordering is given, we'll always sort using it, so only provide it if you really
+ * want the output keys to be sorted. In a map task without map-side combine for example, you
+ * probably want to pass None as the ordering to avoid extra sorting. On the other hand, if you do
+ * want to do combining, having an Ordering is more efficient than not having it.
+ *
+ * At a high level, this class works as follows:
+ *
+ * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
+ *   we want to combine by key, or a simple SizeTrackingPairBuffer if we don't. Inside these
+ *   buffers, we sort elements of type ((Int, K), C) where the Int is the partition ID. This is
+ *   done to avoid calling the partitioner multiple times on the same key (e.g. for
+ *   RangePartitioner).
+ *
+ * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
+ *   by partition ID and possibly second by key or by hash code of the key, if we want to do
+ *   aggregation. For each file, we track how many objects were in each partition in memory, so we
+ *   don't have to write out the partition ID for every element.
+ *
+ * - When the user requests an iterator, the spilled files are merged, along with any remaining
+ *   in-memory data, using the same sort order defined above (unless both sorting and aggregation
+ *   are disabled). If we need to aggregate by key, we either use a total ordering from the
+ *   ordering parameter, or read the keys with the same hash code and compare them with each other
+ *   for equality to merge values.
+ *
+ * - Users are expected to call stop() at the end to delete all the intermediate files.
+ */
+private[spark] class ExternalSorter[K, V, C](
+    aggregator: Option[Aggregator[K, V, C]] = None,
+    partitioner: Option[Partitioner] = None,
+    ordering: Option[Ordering[K]] = None,
+    serializer: Option[Serializer] = None) extends Logging {
+
+  // Without a partitioner everything goes into a single partition (ID 0)
+  private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
+  // Only consult the partitioner when there is more than one partition to choose from
+  private val shouldPartition = numPartitions > 1
+
+  // Storage and serialization handles, taken from the current SparkEnv
+  private val blockManager = SparkEnv.get.blockManager
+  private val diskBlockManager = blockManager.diskBlockManager
+  private val ser = Serializer.getSerializer(serializer)
+  private val serInstance = ser.newInstance()
+
+  private val conf = SparkEnv.get.conf
+  // When false, maybeSpill never writes anything to disk and all data stays in memory
+  private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
+  // Buffer size in bytes for spill writers and readers (the config value is in KB)
+  private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 100) * 1024
+
+  // Size of object batches when reading/writing from serializers.
+  //
+  // Objects are written in batches, with each batch using its own serialization stream. This
+  // cuts down on the size of reference-tracking maps constructed when deserializing a stream.
+  //
+  // NOTE: Setting this too low can cause excessive copying when serializing, since some serializers
+  // grow internal data structures by growing + copying every time the number of objects doubles.
+  private val serializerBatchSize = conf.getLong("spark.shuffle.spill.batchSize", 10000)
+
+  /** Partition ID for `key`; every key maps to partition 0 when we are not partitioning. */
+  private def getPartition(key: K): Int = partitioner match {
+    case Some(p) if shouldPartition => p.getPartition(key)
+    case _ => 0
+  }
+
+  // Data structures to store in-memory objects before we spill. Depending on whether we have an
+  // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
+  // store them in an array buffer.
+  // Whichever one is in use is replaced with a fresh, empty instance after each spill.
+  private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+  private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
+
+  // Number of pairs read from input since last spill; note that we count them even if a value is
+  // merged with a previous key in case we're doing something like groupBy where the result grows
+  private var elementsRead = 0L
+
+  // What threshold of elementsRead we start estimating map size at.
+  private val trackMemoryThreshold = 1000
+
+  // Spilling statistics: number of spills, sum of the estimated in-memory sizes of the spilled
+  // collections, and bytes actually written to spill files
+  private var spillCount = 0
+  private var _memoryBytesSpilled = 0L
+  private var _diskBytesSpilled = 0L
+
+  // Collective memory threshold shared across all running tasks, in bytes: a fraction
+  // (memoryFraction * safetyFraction) of the JVM's maximum heap
+  private val maxMemoryThreshold = {
+    val memoryFraction = conf.getDouble("spark.shuffle.memoryFraction", 0.2)
+    val safetyFraction = conf.getDouble("spark.shuffle.safetyFraction", 0.8)
+    (Runtime.getRuntime.maxMemory * memoryFraction * safetyFraction).toLong
+  }
+
+  // How much of the shared memory pool this collection has claimed, in bytes; raised by
+  // maybeSpill when a claim succeeds and reset to 0 by spill
+  private var myMemoryThreshold = 0L
+
+  // A comparator for keys K that orders them within a partition to allow aggregation or sorting.
+  // Can be a partial ordering by hash code if a total ordering is not provided by the user.
+  // (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
+  // non-equal keys also have this, so we need to do a later pass to find truly equal keys).
+  // Note that we ignore this if no aggregator and no ordering are given.
+  //
+  // Hash codes are compared explicitly instead of subtracted: `h1 - h2` overflows when the hash
+  // codes have opposite signs and large magnitudes, which yields the wrong sign and breaks the
+  // transitivity that sorting and merging rely on. Null keys hash to 0.
+  private val keyComparator: Comparator[K] = ordering.getOrElse(new Comparator[K] {
+    override def compare(a: K, b: K): Int = {
+      val h1 = if (a == null) 0 else a.hashCode()
+      val h2 = if (b == null) 0 else b.hashCode()
+      if (h1 < h2) -1 else if (h1 == h2) 0 else 1
+    }
+  })
+
+  // Orders (partition ID, key) elements by partition first. Keys are only compared (through
+  // keyComparator) when an ordering or an aggregator was supplied, since only then do equal keys
+  // need to end up next to each other; otherwise elements are merely bucketed by partition.
+  private val partitionKeyComparator: Comparator[(Int, K)] = {
+    val compareKeys = ordering.isDefined || aggregator.isDefined
+    new Comparator[(Int, K)] {
+      override def compare(a: (Int, K), b: (Int, K)): Int = {
+        val partitionDiff = a._1 - b._1
+        if (partitionDiff != 0 || !compareKeys) {
+          partitionDiff
+        } else {
+          keyComparator.compare(a._2, b._2)
+        }
+      }
+    }
+  }
+
+  // Information about a spilled file. Includes sizes in bytes of "batches" written by the
+  // serializer as we periodically reset its stream, as well as number of elements in each
+  // partition, used to efficiently keep track of partitions when merging.
+  //  - file / blockId: the temporary file and block ID obtained from createTempBlock()
+  //  - serializerBatchSizes: byte length of each batch, in the order the batches were written
+  //  - elementsPerPartition: element count for each partition ID (length numPartitions)
+  private[this] case class SpilledFile(
+    file: File,
+    blockId: BlockId,
+    serializerBatchSizes: Array[Long],
+    elementsPerPartition: Array[Long])
+  // All spills made so far, in the order they were written
+  private val spills = new ArrayBuffer[SpilledFile]
+
+  /**
+   * Insert all the given records into this sorter. With an aggregator, values are combined per
+   * (partition, key) in the in-memory map; without one they are appended to the buffer as-is.
+   * The in-memory collection may be spilled to disk after any record.
+   */
+  def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    // TODO: stop combining if we find that the reduction factor isn't high
+    aggregator match {
+      case Some(agg) =>
+        // Combine values in-memory first using our AppendOnlyMap. `current` holds the record
+        // being inserted so that a single update closure can be reused for every record.
+        val mergeValue = agg.mergeValue
+        val createCombiner = agg.createCombiner
+        var current: Product2[K, V] = null
+        val update = (hadValue: Boolean, oldValue: C) =>
+          if (hadValue) mergeValue(oldValue, current._2) else createCombiner(current._2)
+        while (records.hasNext) {
+          elementsRead += 1
+          current = records.next()
+          map.changeValue((getPartition(current._1), current._1), update)
+          maybeSpill(usingMap = true)
+        }
+      case None =>
+        // Stick values into our buffer; without an aggregator C is expected to equal V
+        while (records.hasNext) {
+          elementsRead += 1
+          val record = records.next()
+          buffer.insert((getPartition(record._1), record._1), record._2.asInstanceOf[C])
+          maybeSpill(usingMap = false)
+        }
+    }
+  }
+
+  /**
+   * Spill the current in-memory collection to disk if needed.
+   *
+   * Once more than trackMemoryThreshold elements have been read, every 32nd element triggers a
+   * size estimate. If the collection has outgrown the memory it claimed, we try to claim twice
+   * its current size from the shared shuffle memory pool; if the pool cannot provide that much,
+   * we spill instead. Nothing happens when spilling is disabled.
+   *
+   * @param usingMap whether we're using a map or buffer as our current in-memory collection
+   */
+  private def maybeSpill(usingMap: Boolean): Unit = {
+    // TODO: factor this out of both here and ExternalAppendOnlyMap
+    if (spillingEnabled && elementsRead > trackMemoryThreshold && elementsRead % 32 == 0) {
+      val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+      // Estimate once and reuse the value for both the threshold check and the claim below
+      val currentSize = collection.estimateSize()
+      if (currentSize >= myMemoryThreshold) {
+        // TODO: This logic doesn't work if there are two external collections being used in the
+        // same task (e.g. to read shuffle output and write it out into another shuffle) [SPARK-2711]
+        val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+
+        // Atomically check whether there is sufficient memory in the global pool for
+        // us to double our threshold
+        val shouldSpill = shuffleMemoryMap.synchronized {
+          val threadId = Thread.currentThread().getId
+          val previouslyClaimedMemory = shuffleMemoryMap.get(threadId)
+          val availableMemory = maxMemoryThreshold -
+            (shuffleMemoryMap.values.sum - previouslyClaimedMemory.getOrElse(0L))
+
+          // Try to allocate at least 2x more memory, otherwise spill
+          val cannotGrow = availableMemory < currentSize * 2
+          if (!cannotGrow) {
+            shuffleMemoryMap(threadId) = currentSize * 2
+            myMemoryThreshold = currentSize * 2
+          }
+          cannotGrow
+        }
+        // Do not hold lock during spills
+        if (shouldSpill) {
+          spill(currentSize, usingMap)
+        }
+      }
+    }
+  }
+
+  /**
+   * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
+   *
+   * Elements are written in partitionKeyComparator order, in batches of serializerBatchSize
+   * objects; each batch uses its own disk writer (and thus its own serialization stream). After
+   * the file is written, the in-memory collection, this thread's claim on the shuffle memory
+   * pool and the elementsRead counter are all reset.
+   *
+   * @param memorySize estimated in-memory size of the collection in bytes, as measured by the
+   *                   caller; used for logging and for the memory-bytes-spilled statistic
+   * @param usingMap whether we're using a map or buffer as our current in-memory collection
+   */
+  private def spill(memorySize: Long, usingMap: Boolean): Unit = {
+    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+
+    spillCount += 1
+    val threadId = Thread.currentThread().getId
+    logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s so far)"
+      .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount > 1) "s" else ""))
+    val (blockId, file) = diskBlockManager.createTempBlock()
+    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+    var objectsWritten = 0   // Objects written since the last flush
+
+    // List of batch sizes (bytes) in the order they are written to disk
+    val batchSizes = new ArrayBuffer[Long]
+
+    // How many elements we have in each partition
+    val elementsPerPartition = new Array[Long](numPartitions)
+
+    // Flush the disk writer's contents to disk, and update relevant variables
+    def flush(): Unit = {
+      writer.commit()
+      val bytesWritten = writer.bytesWritten
+      batchSizes.append(bytesWritten)
+      _diskBytesSpilled += bytesWritten
+      objectsWritten = 0
+    }
+
+    try {
+      val it = collection.destructiveSortedIterator(partitionKeyComparator)
+      while (it.hasNext) {
+        val elem = it.next()
+        val partitionId = elem._1._1
+        val key = elem._1._2
+        val value = elem._2
+        writer.write(key)
+        writer.write(value)
+        elementsPerPartition(partitionId) += 1
+        objectsWritten += 1
+
+        if (objectsWritten == serializerBatchSize) {
+          // End this batch and open a fresh writer for the next one
+          flush()
+          writer.close()
+          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+        }
+      }
+      if (objectsWritten > 0) {
+        flush()
+      }
+      writer.close()
+    } catch {
+      case e: Exception =>
+        writer.close()
+        file.delete()
+        throw e
+    }
+
+    if (usingMap) {
+      map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+    } else {
+      buffer = new SizeTrackingPairBuffer[(Int, K), C]
+    }
+
+    // Reset the amount of shuffle memory used by this map in the global pool
+    val shuffleMemoryMap = SparkEnv.get.shuffleMemoryMap
+    shuffleMemoryMap.synchronized {
+      shuffleMemoryMap(Thread.currentThread().getId) = 0
+    }
+    myMemoryThreshold = 0
+    // Start counting afresh (as ExternalAppendOnlyMap does) so the new, empty collection is not
+    // size-checked and possibly spilled again before trackMemoryThreshold more elements arrive
+    elementsRead = 0
+
+    spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
+    _memoryBytesSpilled += memorySize
+  }
+
+  /**
+   * Merge the given spill files with the remaining in-memory data, producing one entry per
+   * partition ID in increasing order. Each entry pairs the partition ID with an iterator over that
+   * partition's (K, C) pairs: aggregated when an aggregator is set, merge-sorted by the ordering
+   * when only an ordering is set, and simply concatenated otherwise. This can be used to either
+   * write out a new file or return data to the user.
+   *
+   * The per-partition iterators share the underlying readers, so they must be consumed in order:
+   * you can't "skip ahead" to one partition without reading the previous one.
+   */
+  private def merge(spills: Seq[SpilledFile], inMemory: Iterator[((Int, K), C)])
+      : Iterator[(Int, Iterator[Product2[K, C]])] = {
+    val readers = spills.map(new SpillReader(_))
+    val inMemBuffered = inMemory.buffered
+    (0 until numPartitions).iterator.map { p =>
+      // One iterator per spill file for this partition, followed by the in-memory data for it
+      val inMemIterator = new IteratorForPartition(p, inMemBuffered)
+      val iterators = readers.map(_.readNextPartition()) ++ Seq(inMemIterator)
+      val merged: Iterator[Product2[K, C]] = (aggregator, ordering) match {
+        case (Some(agg), _) =>
+          // Perform partial aggregation across partitions
+          mergeWithAggregation(iterators, agg.mergeCombiners, keyComparator, ordering.isDefined)
+        case (None, Some(ord)) =>
+          // No aggregator given, but we have an ordering (e.g. used by reduce tasks in sortByKey);
+          // sort the elements without trying to merge them
+          mergeSort(iterators, ord)
+        case (None, None) =>
+          iterators.iterator.flatten
+      }
+      (p, merged)
+    }
+  }
+
+  /**
+   * Merge-sort a sequence of (K, C) iterators using a given comparator for the keys. Each input
+   * is expected to be sorted by that comparator already; the result interleaves them in order.
+   */
+  private def mergeSort(iterators: Seq[Iterator[Product2[K, C]]], comparator: Comparator[K])
+      : Iterator[Product2[K, C]] =
+  {
+    type Iter = BufferedIterator[Product2[K, C]]
+    // Only iterators that still have elements are queued, so every queued iterator has a head
+    val nonEmptyIters = iterators.filter(_.hasNext).map(_.buffered)
+    // PriorityQueue dequeues its maximum, so invert the comparator to surface the smallest key
+    val smallestHeadFirst = new Ordering[Iter] {
+      override def compare(x: Iter, y: Iter): Int = -comparator.compare(x.head._1, y.head._1)
+    }
+    val heap = new mutable.PriorityQueue[Iter]()(smallestHeadFirst)
+    heap.enqueue(nonEmptyIters: _*)
+    new Iterator[Product2[K, C]] {
+      override def hasNext: Boolean = heap.nonEmpty
+
+      override def next(): Product2[K, C] = {
+        if (heap.isEmpty) {
+          throw new NoSuchElementException
+        }
+        // Take the head of the iterator with the smallest key, then requeue it if it has more
+        val smallest = heap.dequeue()
+        val pair = smallest.next()
+        if (smallest.hasNext) {
+          heap.enqueue(smallest)
+        }
+        pair
+      }
+    }
+  }
+
+  /**
+   * Merge a sequence of (K, C) iterators by aggregating values for each key, assuming that each
+   * iterator is sorted by key with a given comparator. If the comparator is not a total ordering
+   * (e.g. when we sort objects by hash code and different keys may compare as equal although
+   * they're not), we still merge them by doing equality tests for all keys that compare as equal.
+   *
+   * @param iterators (K, C) iterators, each sorted by `comparator`
+   * @param mergeCombiners function combining two combiners that belong to the same key
+   * @param comparator the key comparator the input iterators are sorted by
+   * @param totalOrder whether `comparator` is a total ordering; if false, keys that compare as
+   *                   equal are additionally checked with == before being merged
+   * @return merged (K, C) pairs, in comparator order
+   */
+  private def mergeWithAggregation(
+      iterators: Seq[Iterator[Product2[K, C]]],
+      mergeCombiners: (C, C) => C,
+      comparator: Comparator[K],
+      totalOrder: Boolean)
+      : Iterator[Product2[K, C]] =
+  {
+    if (!totalOrder) {
+      // We only have a partial ordering, e.g. comparing the keys by hash code, which means that
+      // multiple distinct keys might be treated as equal by the ordering. To deal with this, we
+      // need to read all keys considered equal by the ordering at once and compare them.
+      new Iterator[Iterator[Product2[K, C]]] {
+        val sorted = mergeSort(iterators, comparator).buffered
+
+        // Buffers reused across elements to decrease memory allocation. This relies on flatMap
+        // draining each returned group iterator before asking for the next group.
+        val keys = new ArrayBuffer[K]
+        val combiners = new ArrayBuffer[C]
+
+        override def hasNext: Boolean = sorted.hasNext
+
+        override def next(): Iterator[Product2[K, C]] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
+          keys.clear()
+          combiners.clear()
+          val firstPair = sorted.next()
+          keys += firstPair._1
+          combiners += firstPair._2
+          val key = firstPair._1
+          while (sorted.hasNext && comparator.compare(sorted.head._1, key) == 0) {
+            val pair = sorted.next()
+            var i = 0
+            var foundKey = false
+            while (i < keys.size && !foundKey) {
+              if (keys(i) == pair._1) {
+                combiners(i) = mergeCombiners(combiners(i), pair._2)
+                foundKey = true
+              }
+              i += 1
+            }
+            if (!foundKey) {
+              keys += pair._1
+              combiners += pair._2
+            }
+          }
+
+          // Note that we return an iterator of elements since we could've had many keys marked
+          // equal by the partial order; we flatten this below to get a flat iterator of (K, C).
+          keys.iterator.zip(combiners.iterator)
+        }
+      }.flatMap(i => i)
+    } else {
+      // We have a total ordering, so the objects with the same key are sequential.
+      new Iterator[Product2[K, C]] {
+        val sorted = mergeSort(iterators, comparator).buffered
+
+        override def hasNext: Boolean = sorted.hasNext
+
+        override def next(): Product2[K, C] = {
+          if (!hasNext) {
+            throw new NoSuchElementException
+          }
+          val elem = sorted.next()
+          val k = elem._1
+          var c = elem._2
+          // Consume every following pair with the same key. The pair must be taken with next();
+          // peeking via head alone never advances `sorted` and would loop forever.
+          while (sorted.hasNext && sorted.head._1 == k) {
+            c = mergeCombiners(c, sorted.next()._2)
+          }
+          (k, c)
+        }
+      }
+    }
+  }
+
+  /**
+   * An internal class for reading a spilled file partition by partition. Expects all the
+   * partitions to be requested in order.
+   */
+  private[this] class SpillReader(spill: SpilledFile) {
+    val fileStream = new FileInputStream(spill.file)
+    val bufferedStream = new BufferedInputStream(fileStream, fileBufferSize)
+
+    // Track which partition and which batch stream we're in. These will be the indices of
+    // the next element we will read. We'll also store the last partition read so that
+    // readNextPartition() can figure out what partition that was from.
+    var partitionId = 0
+    var indexInPartition = 0L
+    var batchStreamsRead = 0
+    var indexInBatch = 0
+    var lastPartitionId = 0
+
+    skipToNextPartition()
+
+    // An intermediate stream that reads from exactly one batch
+    // This guards against pre-fetching and other arbitrary behavior of higher level streams
+    var batchStream = nextBatchStream()
+    var compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+    var deserStream = serInstance.deserializeStream(compressedStream)
+    var nextItem: (K, C) = null
+    var finished = false
+
+    /** Construct a stream that only reads from the next batch */
+    def nextBatchStream(): InputStream = {
+      if (batchStreamsRead < spill.serializerBatchSizes.length) {
+        batchStreamsRead += 1
+        ByteStreams.limit(bufferedStream, spill.serializerBatchSizes(batchStreamsRead - 1))
+      } else {
+        // No more batches left; give an empty stream
+        bufferedStream
+      }
+    }
+
+    /**
+     * Update partitionId if we have reached the end of our current partition, possibly skipping
+     * empty partitions on the way.
+     */
+    private def skipToNextPartition() {
+      while (partitionId < numPartitions &&
+          indexInPartition == spill.elementsPerPartition(partitionId)) {
+        partitionId += 1
+        indexInPartition = 0L
+      }
+    }
+
+    /**
+     * Return the next (K, C) pair from the deserialization stream and update partitionId,
+     * indexInPartition, indexInBatch and such to match its location.
+     *
+     * If the current batch is drained and more elements remain, construct a stream for the next
+     * batch so that the following call reads from it. If no more pairs are left, return null.
+     */
+    private def readNextItem(): (K, C) = {
+      // partitionId == numPartitions while not `finished` can only happen when the initial
+      // skipToNextPartition() found no elements at all, so there is nothing to read then.
+      if (finished || partitionId == numPartitions) {
+        null
+      } else {
+        val k = deserStream.readObject().asInstanceOf[K]
+        val c = deserStream.readObject().asInstanceOf[C]
+        lastPartitionId = partitionId
+        indexInBatch += 1
+        // Update the partition location of the element we're reading
+        indexInPartition += 1
+        skipToNextPartition()
+        if (partitionId == numPartitions) {
+          // We've read the last element of the last partition: remember that we're done and close
+          // the stream we actually read from. This is checked before the batch switch below so we
+          // never open a stream past the final batch: nextBatchStream() would return the
+          // unlimited bufferedStream, and wrapping it in a fresh compression + deserialization
+          // stream may eagerly read a header (ObjectInputStream's constructor does) and hit EOF.
+          finished = true
+          deserStream.close()
+        } else if (indexInBatch == serializerBatchSize) {
+          // Start reading the next batch if we're done with this one
+          batchStream = nextBatchStream()
+          compressedStream = blockManager.wrapForCompression(spill.blockId, batchStream)
+          deserStream = serInstance.deserializeStream(compressedStream)
+          indexInBatch = 0
+        }
+        (k, c)
+      }
+    }
+
+    // Partition ID that the next call to readNextPartition() will iterate over; each call claims
+    // the current value and increments it, so partitions are handed out in increasing order
+    var nextPartitionToRead = 0
+
+    /**
+     * Return an iterator over the elements of the next partition in this spill file.
+     *
+     * All returned iterators share this reader's single stream and its one-element lookahead
+     * (nextItem), so they must be drained in partition order: if an earlier partition still has
+     * unread elements, the assert in hasNext fails for a later partition's iterator. An element
+     * read ahead that belongs to a later partition stays in nextItem for that partition's iterator.
+     */
+    def readNextPartition(): Iterator[Product2[K, C]] = new Iterator[Product2[K, C]] {
+      val myPartition = nextPartitionToRead
+      nextPartitionToRead += 1
+
+      override def hasNext: Boolean = {
+        // Read one element ahead unless one is already buffered
+        if (nextItem == null) {
+          nextItem = readNextItem()
+          if (nextItem == null) {
+            return false
+          }
+        }
+        // lastPartitionId is the partition of the buffered element, because readNextItem() sets
+        // it on every read and nextItem always holds the most recently read element
+        assert(lastPartitionId >= myPartition)
+        // Check that we're still in the right partition; note that readNextItem will have returned
+        // null at EOF above so we would've returned false there
+        lastPartitionId == myPartition
+      }
+
+      override def next(): Product2[K, C] = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        // Hand out the buffered element and clear the lookahead so hasNext reads a fresh one
+        val item = nextItem
+        nextItem = null
+        item
+      }
+    }
+  }
+
+  /**
+   * Return an iterator over all the data written to this object, grouped by partition and
+   * aggregated by the requested aggregator. For each partition we then have an iterator over its
+   * contents, and these are expected to be accessed in order (you can't "skip ahead" to one
+   * partition without reading the previous one). Guaranteed to return a (partition ID, iterator)
+   * pair for each partition, in order of partition ID.
+   *
+   * For now, we just merge all the spilled files in one pass, but this can be modified to
+   * support hierarchical merging.
+   */
+  def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
+    // In-memory data lives in the map when an aggregator is defined, otherwise in the buffer
+    val collection: SizeTrackingPairCollection[(Int, K), C] =
+      if (aggregator.isDefined) map else buffer
+    if (spills.isEmpty) {
+      // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
+      // we don't even need to sort by anything other than partition ID
+      if (ordering.isEmpty) {
+        // The user hasn't requested sorted keys, so only sort by partition ID, not key
+        val partitionComparator = new Comparator[(Int, K)] {
+          override def compare(a: (Int, K), b: (Int, K)): Int = {
+            // Explicit three-way comparison instead of `a._1 - b._1`, which relies on the
+            // subtraction never overflowing
+            Ordering.Int.compare(a._1, b._1)
+          }
+        }
+        groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+      } else {
+        // We do need to sort by both partition ID and key
+        groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
+      }
+    } else {
+      // General case: merge spilled and in-memory data
+      merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
+    }
+  }
+
+  /**
+   * Return an iterator over all the data written to this object, aggregated by our aggregator.
+   * This concatenates the per-partition iterators of partitionedIterator, one after another.
+   */
+  def iterator: Iterator[Product2[K, C]] =
+    partitionedIterator.flatMap { case (_, partitionElements) => partitionElements }
+
+  /**
+   * Delete every spill file created so far and forget about them. The boolean result of
+   * File.delete() is ignored, so a file that cannot be deleted is silently left behind.
+   */
+  def stop(): Unit = {
+    for (spilled <- spills) {
+      spilled.file.delete()
+    }
+    spills.clear()
+  }
+
+  /** Current value of `_memoryBytesSpilled` (presumably in-memory bytes spilled so far). */
+  def memoryBytesSpilled: Long = _memoryBytesSpilled
+
+  /** Current value of `_diskBytesSpilled` (presumably bytes written to disk by spills). */
+  def diskBytesSpilled: Long = _diskBytesSpilled
+
+  /**
+   * Split a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*
+   * into one sub-iterator per partition ID in 0 until numPartitions.
+   *
+   * All sub-iterators share one buffered view of the input, so they must be drained in order of
+   * partition ID. A partition with no data gets an empty sub-iterator.
+   *
+   * @param data an iterator of elements, assumed to already be sorted by partition ID
+   */
+  private def groupByPartition(data: Iterator[((Int, K), C)])
+      : Iterator[(Int, Iterator[Product2[K, C]])] =
+  {
+    val shared = data.buffered
+    Iterator.range(0, numPartitions).map { partition =>
+      (partition, new IteratorForPartition(partition, shared))
+    }
+  }
+
+  /**
+   * An iterator over the (key, combiner) pairs of one partition, read from a buffered stream of
+   * ((partition, key), combiner) elements that is shared with other partitions' iterators. It
+   * stops as soon as the stream is exhausted or its head belongs to another partition, so it only
+   * sees this partition's data if that data is next in the stream. Used to make it easier to
+   * return partitioned iterators from our in-memory collection.
+   */
+  private[this] class IteratorForPartition(partitionId: Int, data: BufferedIterator[((Int, K), C)])
+    extends Iterator[Product2[K, C]]
+  {
+    override def hasNext: Boolean = {
+      if (data.hasNext) {
+        val headPartitionId = data.head._1._1
+        headPartitionId == partitionId
+      } else {
+        false
+      }
+    }
+
+    override def next(): Product2[K, C] = {
+      if (!hasNext) {
+        throw new NoSuchElementException
+      }
+      val ((_, key), combiner) = data.next()
+      (key, combiner)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
index de61e1d..eb4de41 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -20,8 +20,9 @@ package org.apache.spark.util.collection
 /**
  * An append-only map that keeps track of its estimated size in bytes.
  */
-private[spark] class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] with SizeTracker {
-
+private[spark] class SizeTrackingAppendOnlyMap[K, V]
+  extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V]
+{
   override def update(key: K, value: V): Unit = {
     super.update(key, value)
     super.afterUpdate()

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
new file mode 100644
index 0000000..9e9c16c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes.
+ *
+ * Pairs are interleaved in one flat AnyRef array (key of pair i at index 2 * i, its value at
+ * 2 * i + 1), which lets KVArraySortDataFormat sort keys and values together. Capacity starts
+ * at initialCapacity pairs and doubles on demand, up to 2^29 pairs.
+ */
+private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
+  extends SizeTracker with SizeTrackingPairCollection[K, V]
+{
+  require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+  require(initialCapacity >= 1, "Invalid initial capacity")
+
+  // Number of pairs the backing array currently has room for
+  private var capacity = initialCapacity
+  // Number of pairs stored so far
+  private var curSize = 0
+  // Backing store: two consecutive slots (key, value) per pair
+  private var data = new Array[AnyRef](2 * initialCapacity)
+
+  /** Append a (key, value) pair, growing the backing array first if it is full. */
+  def insert(key: K, value: V): Unit = {
+    if (curSize == capacity) {
+      growArray()
+    }
+    val keySlot = 2 * curSize
+    data(keySlot) = key.asInstanceOf[AnyRef]
+    data(keySlot + 1) = value.asInstanceOf[AnyRef]
+    curSize += 1
+    afterUpdate()
+  }
+
+  /** Number of pairs currently stored. */
+  override def size: Int = curSize
+
+  /**
+   * Iterate over the stored pairs in their current array order. The end bound (curSize) is
+   * re-read on every hasNext call rather than captured when the iterator is created.
+   */
+  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
+    var cursor = 0
+
+    override def hasNext: Boolean = cursor < curSize
+
+    override def next(): (K, V) = {
+      if (!hasNext) {
+        throw new NoSuchElementException
+      }
+      val keySlot = 2 * cursor
+      cursor += 1
+      (data(keySlot).asInstanceOf[K], data(keySlot + 1).asInstanceOf[V])
+    }
+  }
+
+  /** Double the capacity, copying the existing pairs into a new array; fails at 2^29 pairs. */
+  private def growArray(): Unit = {
+    if (capacity == (1 << 29)) {
+      // Doubling would need an array with more than Int.MaxValue slots, so refuse instead
+      throw new Exception("Can't grow buffer beyond 2^29 elements")
+    }
+    val newCapacity = capacity * 2
+    val grown = new Array[AnyRef](2 * newCapacity)
+    Array.copy(data, 0, grown, 0, 2 * capacity)
+    data = grown
+    capacity = newCapacity
+    resetSamples()
+  }
+
+  /**
+   * Sort the first curSize pairs of the backing array by key with the given comparator, then
+   * iterate over them. The pairs remain in the buffer (only reordered), so for this class the
+   * operation is not really destructive.
+   */
+  override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
+    val sorter = new Sorter(new KVArraySortDataFormat[K, AnyRef])
+    sorter.sort(data, 0, curSize, keyComparator)
+    iterator
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
new file mode 100644
index 0000000..faa4e2b
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+// TODO: should extend Iterable[Product2[K, V]] instead of (K, V)
+/**
+ * A common interface for our size-tracking collections of key-value pairs, which are used in
+ * external operations. These all support estimating the size and obtaining a memory-efficient
+ * sorted iterator.
+ *
+ * @tparam K the key type
+ * @tparam V the value type
+ */
+private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] {
+  /** Estimate the collection's current memory usage in bytes. */
+  def estimateSize(): Long
+
+  /**
+   * Iterate through the data in a given key order. This may destroy the underlying collection,
+   * so callers should not rely on its contents afterwards.
+   *
+   * @param keyComparator comparator defining the order of the returned pairs
+   * @return an iterator over the (key, value) pairs, ordered by keyComparator
+   */
+  def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)]
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/e9662844/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d1cb2d9..a41914a 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
   test("ShuffledRDD") {
     testRDD(rdd => {
       // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
-      new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
+      new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
     })
   }