Posted to commits@spark.apache.org by pw...@apache.org on 2015/05/01 08:12:03 UTC

[1/2] spark git commit: [SPARK-4550] In sort-based shuffle, store map outputs in serialized form

Repository: spark
Updated Branches:
  refs/heads/master a9fc50552 -> 0a2b15ce4


http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index de26aa3..20fd22b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -19,19 +19,24 @@ package org.apache.spark.util.collection
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.scalatest.{PrivateMethodTester, FunSuite}
-
-import org.apache.spark._
+import org.scalatest.{FunSuite, PrivateMethodTester}
 
 import scala.util.Random
 
+import org.apache.spark._
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+
 class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMethodTester {
-  private def createSparkConf(loadDefaults: Boolean): SparkConf = {
+  private def createSparkConf(loadDefaults: Boolean, kryo: Boolean): SparkConf = {
     val conf = new SparkConf(loadDefaults)
-    // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
-    // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
-    conf.set("spark.serializer.objectStreamReset", "1")
-    conf.set("spark.serializer", "org.apache.spark.serializer.JavaSerializer")
+    if (kryo) {
+      conf.set("spark.serializer", classOf[KryoSerializer].getName)
+    } else {
+      // Make the Java serializer write a reset instruction (TC_RESET) after each object to test
+      // for a bug we had with bytes written past the last object in a batch (SPARK-2792)
+      conf.set("spark.serializer.objectStreamReset", "1")
+      conf.set("spark.serializer", classOf[JavaSerializer].getName)
+    }
     // Ensure that we actually have multiple batches per spill file
     conf.set("spark.shuffle.spill.batchSize", "10")
     conf
@@ -47,8 +52,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(!sorter.invokePrivate(bypassMergeSort()), "sorter bypassed merge-sort")
   }
 
-  test("empty data stream") {
-    val conf = new SparkConf(false)
+  test("empty data stream with kryo ser") {
+    emptyDataStream(createSparkConf(false, true))
+  }
+
+  test("empty data stream with java ser") {
+    emptyDataStream(createSparkConf(false, false))
+  }
+
+  def emptyDataStream(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -81,8 +93,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     sorter4.stop()
   }
 
-  test("few elements per partition") {
-    val conf = createSparkConf(false)
+  test("few elements per partition with kryo ser") {
+    fewElementsPerPartition(createSparkConf(false, true))
+  }
+
+  test("few elements per partition with java ser") {
+    fewElementsPerPartition(createSparkConf(false, false))
+  }
+
+  def fewElementsPerPartition(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -123,8 +142,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     sorter4.stop()
   }
 
-  test("empty partitions with spilling") {
-    val conf = createSparkConf(false)
+  test("empty partitions with spilling with kryo ser") {
+    emptyPartitionsWithSpilling(createSparkConf(false, true))
+  }
+
+  test("empty partitions with spilling with java ser") {
+    emptyPartitionsWithSpilling(createSparkConf(false, false))
+  }
+
+  def emptyPartitionsWithSpilling(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
@@ -149,8 +175,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     sorter.stop()
   }
 
-  test("empty partitions with spilling, bypass merge-sort") {
-    val conf = createSparkConf(false)
+  test("empty partitions with spilling, bypass merge-sort with kryo ser") {
+    emptyPartitionsWithSpillingBypassMergeSort(createSparkConf(false, true))
+  }
+
+  test("empty partitions with spilling, bypass merge-sort with java ser") {
+    emptyPartitionsWithSpillingBypassMergeSort(createSparkConf(false, false))
+  }
+
+  def emptyPartitionsWithSpillingBypassMergeSort(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.spill.initialMemoryThreshold", "512")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
@@ -174,8 +207,17 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     sorter.stop()
   }
 
-  test("spilling in local cluster") {
-    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+  test("spilling in local cluster with kryo ser") {
+    // Load defaults, otherwise SPARK_HOME is not found
+    testSpillingInLocalCluster(createSparkConf(true, true))
+  }
+
+  test("spilling in local cluster with java ser") {
+    // Load defaults, otherwise SPARK_HOME is not found
+    testSpillingInLocalCluster(createSparkConf(true, false))
+  }
+
+  def testSpillingInLocalCluster(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
@@ -245,8 +287,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(resultE === (0 until 100000).map(i => (i/4, i)).toSeq)
   }
 
-  test("spilling in local cluster with many reduce tasks") {
-    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+  test("spilling in local cluster with many reduce tasks with kryo ser") {
+    spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, true))
+  }
+
+  test("spilling in local cluster with many reduce tasks with java ser") {
+    spillingInLocalClusterWithManyReduceTasks(createSparkConf(true, false))
+  }
+
+  def spillingInLocalClusterWithManyReduceTasks(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
@@ -317,7 +366,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("cleanup of intermediate files in sorter") {
-    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -344,7 +393,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("cleanup of intermediate files in sorter, bypass merge-sort") {
-    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -367,7 +416,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("cleanup of intermediate files in sorter if there are errors") {
-    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -392,7 +441,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("cleanup of intermediate files in sorter if there are errors, bypass merge-sort") {
-    val conf = createSparkConf(true)  // Load defaults, otherwise SPARK_HOME is not found
+    val conf = createSparkConf(true, false)  // Load defaults, otherwise SPARK_HOME is not found
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -414,7 +463,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("cleanup of intermediate files in shuffle") {
-    val conf = createSparkConf(false)
+    val conf = createSparkConf(false, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -429,7 +478,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("cleanup of intermediate files in shuffle with errors") {
-    val conf = createSparkConf(false)
+    val conf = createSparkConf(false, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -450,8 +499,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(diskBlockManager.getAllFiles().length === 2)
   }
 
-  test("no partial aggregation or sorting") {
-    val conf = createSparkConf(false)
+  test("no partial aggregation or sorting with kryo ser") {
+    noPartialAggregationOrSorting(createSparkConf(false, true))
+  }
+
+  test("no partial aggregation or sorting with java ser") {
+    noPartialAggregationOrSorting(createSparkConf(false, false))
+  }
+
+  def noPartialAggregationOrSorting(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -465,8 +521,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(results === expected)
   }
 
-  test("partial aggregation without spill") {
-    val conf = createSparkConf(false)
+  test("partial aggregation without spill with kryo ser") {
+    partialAggregationWithoutSpill(createSparkConf(false, true))
+  }
+
+  test("partial aggregation without spill with java ser") {
+    partialAggregationWithoutSpill(createSparkConf(false, false))
+  }
+
+  def partialAggregationWithoutSpill(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -481,8 +544,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(results === expected)
   }
 
-  test("partial aggregation with spill, no ordering") {
-    val conf = createSparkConf(false)
+  test("partial aggregation with spill, no ordering with kryo ser") {
+    partialAggregationWithSpillNoOrdering(createSparkConf(false, true))
+  }
+
+  test("partial aggregation with spill, no ordering with java ser") {
+    partialAggregationWithSpillNoOrdering(createSparkConf(false, false))
+  }
+
+  def partialAggregationWithSpillNoOrdering(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -497,8 +567,16 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(results === expected)
   }
 
-  test("partial aggregation with spill, with ordering") {
-    val conf = createSparkConf(false)
+  test("partial aggregation with spill, with ordering with kryo ser") {
+    partialAggregationWithSpillWithOrdering(createSparkConf(false, true))
+  }
+
+
+  test("partial aggregation with spill, with ordering with java ser") {
+    partialAggregationWithSpillWithOrdering(createSparkConf(false, false))
+  }
+
+  def partialAggregationWithSpillWithOrdering(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -517,8 +595,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(results === expected)
   }
 
-  test("sorting without aggregation, no spill") {
-    val conf = createSparkConf(false)
+  test("sorting without aggregation, no spill with kryo ser") {
+    sortingWithoutAggregationNoSpill(createSparkConf(false, true))
+  }
+
+  test("sorting without aggregation, no spill with java ser") {
+    sortingWithoutAggregationNoSpill(createSparkConf(false, false))
+  }
+
+  def sortingWithoutAggregationNoSpill(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -534,8 +619,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assert(results === expected)
   }
 
-  test("sorting without aggregation, with spill") {
-    val conf = createSparkConf(false)
+  test("sorting without aggregation, with spill with kryo ser") {
+    sortingWithoutAggregationWithSpill(createSparkConf(false, true))
+  }
+
+  test("sorting without aggregation, with spill with java ser") {
+    sortingWithoutAggregationWithSpill(createSparkConf(false, false))
+  }
+
+  def sortingWithoutAggregationWithSpill(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -552,7 +644,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("spilling with hash collisions") {
-    val conf = createSparkConf(true)
+    val conf = createSparkConf(true, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -609,7 +701,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("spilling with many hash collisions") {
-    val conf = createSparkConf(true)
+    val conf = createSparkConf(true, false)
     conf.set("spark.shuffle.memoryFraction", "0.0001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -632,7 +724,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("spilling with hash collisions using the Int.MaxValue key") {
-    val conf = createSparkConf(true)
+    val conf = createSparkConf(true, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -656,7 +748,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("spilling with null keys and values") {
-    val conf = createSparkConf(true)
+    val conf = createSparkConf(true, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)
 
@@ -685,7 +777,7 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
   }
 
   test("conditions for bypassing merge-sort") {
-    val conf = createSparkConf(false)
+    val conf = createSparkConf(false, false)
     conf.set("spark.shuffle.memoryFraction", "0.001")
     conf.set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.SortShuffleManager")
     sc = new SparkContext("local", "test", conf)
@@ -718,8 +810,15 @@ class ExternalSorterSuite extends FunSuite with LocalSparkContext with PrivateMe
     assertDidNotBypassMergeSort(sorter4)
   }
 
-  test("sort without breaking sorting contracts") {
-    val conf = createSparkConf(true)
+  test("sort without breaking sorting contracts with kryo ser") {
+    sortWithoutBreakingSortingContracts(createSparkConf(true, true))
+  }
+
+  test("sort without breaking sorting contracts with java ser") {
+    sortWithoutBreakingSortingContracts(createSparkConf(true, false))
+  }
+
+  def sortWithoutBreakingSortingContracts(conf: SparkConf) {
     conf.set("spark.shuffle.memoryFraction", "0.01")
     conf.set("spark.shuffle.manager", "sort")
     sc = new SparkContext("local-cluster[1,1,512]", "test", conf)

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
new file mode 100644
index 0000000..b5a2d9e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/PartitionedSerializedPairBufferSuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream}
+
+import com.google.common.io.ByteStreams
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.storage.{FileSegment, BlockObjectWriter}
+
+class PartitionedSerializedPairBufferSuite extends FunSuite {
+  test("OrderedInputStream single record") {
+    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+
+    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+    val struct = SomeStruct("something", 5)
+    buffer.insert(4, 10, struct)
+
+    val bytes = ByteStreams.toByteArray(buffer.orderedInputStream)
+
+    val baos = new ByteArrayOutputStream()
+    val stream = serializerInstance.serializeStream(baos)
+    stream.writeObject(10)
+    stream.writeObject(struct)
+    stream.close()
+
+    baos.toByteArray should be (bytes)
+  }
+
+  test("insert single record") {
+    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+    val struct = SomeStruct("something", 5)
+    buffer.insert(4, 10, struct)
+    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
+    elements.size should be (1)
+    elements.head should be (((4, 10), struct))
+  }
+
+  test("insert multiple records") {
+    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+    val struct1 = SomeStruct("something1", 8)
+    buffer.insert(6, 1, struct1)
+    val struct2 = SomeStruct("something2", 9)
+    buffer.insert(4, 2, struct2)
+    val struct3 = SomeStruct("something3", 10)
+    buffer.insert(5, 3, struct3)
+
+    val elements = buffer.partitionedDestructiveSortedIterator(None).toArray
+    elements.size should be (3)
+    elements(0) should be (((4, 2), struct2))
+    elements(1) should be (((5, 3), struct3))
+    elements(2) should be (((6, 1), struct1))
+  }
+
+  test("write single record") {
+    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+    val struct = SomeStruct("something", 5)
+    buffer.insert(4, 10, struct)
+    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
+    val writer = new SimpleBlockObjectWriter
+    assert(it.hasNext)
+    it.nextPartition should be (4)
+    it.writeNext(writer)
+    assert(!it.hasNext)
+
+    val stream = serializerInstance.deserializeStream(writer.getInputStream)
+    stream.readObject[AnyRef]() should be (10)
+    stream.readObject[AnyRef]() should be (struct)
+  }
+
+  test("write multiple records") {
+    val serializerInstance = new KryoSerializer(new SparkConf()).newInstance
+    val buffer = new PartitionedSerializedPairBuffer[Int, SomeStruct](4, 32, serializerInstance)
+    val struct1 = SomeStruct("something1", 8)
+    buffer.insert(6, 1, struct1)
+    val struct2 = SomeStruct("something2", 9)
+    buffer.insert(4, 2, struct2)
+    val struct3 = SomeStruct("something3", 10)
+    buffer.insert(5, 3, struct3)
+
+    val it = buffer.destructiveSortedWritablePartitionedIterator(None)
+    val writer = new SimpleBlockObjectWriter
+    assert(it.hasNext)
+    it.nextPartition should be (4)
+    it.writeNext(writer)
+    assert(it.hasNext)
+    it.nextPartition should be (5)
+    it.writeNext(writer)
+    assert(it.hasNext)
+    it.nextPartition should be (6)
+    it.writeNext(writer)
+    assert(!it.hasNext)
+
+    val stream = serializerInstance.deserializeStream(writer.getInputStream)
+    val iter = stream.asIterator
+    iter.next() should be (2)
+    iter.next() should be (struct2)
+    iter.next() should be (3)
+    iter.next() should be (struct3)
+    iter.next() should be (1)
+    iter.next() should be (struct1)
+    assert(!iter.hasNext)
+  }
+}
+
+case class SomeStruct(str: String, num: Int)
+
+class SimpleBlockObjectWriter extends BlockObjectWriter(null) {
+  val baos = new ByteArrayOutputStream()
+
+  override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
+    baos.write(bytes, offs, len)
+  }
+
+  def getInputStream(): InputStream = new ByteArrayInputStream(baos.toByteArray)
+
+  override def open(): BlockObjectWriter = this
+  override def close(): Unit = { }
+  override def isOpen: Boolean = true
+  override def commitAndClose(): Unit = { }
+  override def revertPartialWritesAndClose(): Unit = { }
+  override def fileSegment(): FileSegment = null
+  override def write(key: Any, value: Any): Unit = { }
+  override def recordWritten(): Unit = { }
+  override def write(b: Int): Unit = { }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
index cec97de..9552f41 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala
@@ -50,10 +50,10 @@ private[sql] class Serializer2SerializationStream(
   extends SerializationStream with Logging {
 
   val rowOut = new DataOutputStream(out)
-  val writeKey = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValue = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
 
-  def writeObject[T: ClassTag](t: T): SerializationStream = {
+  override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
     writeKey(kv._1)
     writeValue(kv._2)
@@ -61,6 +61,16 @@ private[sql] class Serializer2SerializationStream(
     this
   }
 
+  override def writeKey[T: ClassTag](t: T): SerializationStream = {
+    writeKeyFunc(t.asInstanceOf[Row])
+    this
+  }
+
+  override def writeValue[T: ClassTag](t: T): SerializationStream = {
+    writeValueFunc(t.asInstanceOf[Row])
+    this
+  }
+
   def flush(): Unit = {
     rowOut.flush()
   }
@@ -83,17 +93,27 @@ private[sql] class Serializer2DeserializationStream(
 
   val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
   val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKey = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValue = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
+  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
 
-  def readObject[T: ClassTag](): T = {
-    readKey()
-    readValue()
+  override def readObject[T: ClassTag](): T = {
+    readKeyFunc()
+    readValueFunc()
 
     (key, value).asInstanceOf[T]
   }
 
-  def close(): Unit = {
+  override def readKey[T: ClassTag](): T = {
+    readKeyFunc()
+    key.asInstanceOf[T]
+  }
+
+  override def readValue[T: ClassTag](): T = {
+    readValueFunc()
+    value.asInstanceOf[T]
+  }
+
+  override def close(): Unit = {
     rowIn.close()
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
index f2d1353..baa9761 100644
--- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
@@ -46,7 +46,8 @@ object StoragePerfTester {
     val totalRecords = dataSizeMb * 1000
     val recordsPerMap = totalRecords / numMaps
 
-    val writeData = "1" * recordLength
+    val writeKey = "1" * (recordLength / 2)
+    val writeValue = "1" * (recordLength / 2)
     val executor = Executors.newFixedThreadPool(numMaps)
 
     val conf = new SparkConf()
@@ -63,7 +64,7 @@ object StoragePerfTester {
         new KryoSerializer(sc.conf), new ShuffleWriteMetrics())
       val writers = shuffle.writers
       for (i <- 1 to recordsPerMap) {
-        writers(i % numOutputSplits).write(writeData)
+        writers(i % numOutputSplits).write(writeKey, writeValue)
       }
       writers.map { w =>
         w.commitAndClose()




[2/2] spark git commit: [SPARK-4550] In sort-based shuffle, store map outputs in serialized form

Posted by pw...@apache.org.
[SPARK-4550] In sort-based shuffle, store map outputs in serialized form

Refer to the JIRA for the design doc and some perf results.

I wanted to call out some of the potentially controversial changes up front:
* Map outputs are only stored in serialized form when Kryo is in use.  I'm still unsure whether Java-serialized objects can be relocated.  At the very least, Java serialization writes out a stream header, which causes problems with the current approach, so I decided to leave investigating this to future work.
* The shuffle now explicitly operates on key-value pairs instead of arbitrary objects.  Data is written to shuffle files as alternating keys and values instead of key-value tuples, and `BlockObjectWriter.write` now accepts a key argument and a value argument instead of a single object (a short sketch of the new stream API follows this list).
* The map output buffer can hold at most Integer.MAX_VALUE bytes, though this wouldn't be terribly difficult to change.
* When spilling occurs, the objects that are still in memory at merge time end up serialized and deserialized an extra time.
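
To make the new write path concrete, here is a minimal sketch (not part of the patch; the KeyValueStreamSketch wrapper is hypothetical) exercising the writeKey/writeValue and asKeyValueIterator additions shown in the Serializer.scala diff below. Keys and values are written as alternating objects rather than as (key, value) tuples.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

// Hypothetical driver; only the stream API calls come from this commit.
object KeyValueStreamSketch {
  def main(args: Array[String]): Unit = {
    val serInstance = new KryoSerializer(new SparkConf()).newInstance()

    // Write two records as alternating keys and values (no (k, v) tuples).
    val baos = new ByteArrayOutputStream()
    val out = serInstance.serializeStream(baos)
    out.writeKey(1).writeValue("a")
    out.writeKey(2).writeValue("b")
    out.close()

    // Read them back through the new key/value iterator.
    val in = serInstance.deserializeStream(new ByteArrayInputStream(baos.toByteArray))
    in.asKeyValueIterator.foreach { case (k, v) => println(s"$k -> $v") }
  }
}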

Author: Sandy Ryza <sa...@cloudera.com>

Closes #4450 from sryza/sandy-spark-4550 and squashes the following commits:

8c70dd9 [Sandy Ryza] Fix serialization
9c16fe6 [Sandy Ryza] Fix a couple tests and move getAutoReset to KryoSerializerInstance
6c54e06 [Sandy Ryza] Fix scalastyle
d8462d8 [Sandy Ryza] SPARK-4550


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a2b15ce
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a2b15ce
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a2b15ce

Branch: refs/heads/master
Commit: 0a2b15ce43cf6096e1a7ae060b7c8a4010ce3b92
Parents: a9fc505
Author: Sandy Ryza <sa...@cloudera.com>
Authored: Thu Apr 30 23:14:14 2015 -0700
Committer: Patrick Wendell <pa...@databricks.com>
Committed: Thu Apr 30 23:14:14 2015 -0700

----------------------------------------------------------------------
 .../spark/serializer/KryoSerializer.scala       |  10 +
 .../apache/spark/serializer/Serializer.scala    |  31 +++
 .../spark/shuffle/hash/HashShuffleWriter.scala  |   2 +-
 .../spark/storage/BlockObjectWriter.scala       |  37 ++-
 .../storage/ShuffleBlockFetcherIterator.scala   |   6 +-
 .../spark/util/collection/ChainedBuffer.scala   | 144 +++++++++++
 .../util/collection/ExternalAppendOnlyMap.scala |   6 +-
 .../spark/util/collection/ExternalSorter.scala  | 144 ++++++-----
 .../spark/util/collection/PairIterator.scala    |  24 ++
 .../collection/PartitionedAppendOnlyMap.scala   |  44 ++++
 .../util/collection/PartitionedPairBuffer.scala |  92 +++++++
 .../PartitionedSerializedPairBuffer.scala       | 254 +++++++++++++++++++
 .../collection/SizeTrackingAppendOnlyMap.scala  |   2 +-
 .../collection/SizeTrackingPairBuffer.scala     |  86 -------
 .../collection/SizeTrackingPairCollection.scala |  34 ---
 .../WritablePartitionedPairCollection.scala     | 113 +++++++++
 .../spark/serializer/KryoSerializerSuite.scala  |  15 ++
 .../spark/serializer/TestSerializer.scala       |   4 +-
 .../shuffle/hash/HashShuffleManagerSuite.scala  |  12 +-
 .../spark/storage/BlockObjectWriterSuite.scala  |   8 +-
 .../util/collection/ChainedBufferSuite.scala    | 143 +++++++++++
 .../util/collection/ExternalSorterSuite.scala   | 189 ++++++++++----
 .../PartitionedSerializedPairBufferSuite.scala  | 149 +++++++++++
 .../sql/execution/SparkSqlSerializer2.scala     |  38 ++-
 .../apache/spark/tools/StoragePerfTester.scala  |   5 +-
 25 files changed, 1321 insertions(+), 271 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 754832b..b7bc087 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -200,6 +200,16 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
   override def deserializeStream(s: InputStream): DeserializationStream = {
     new KryoDeserializationStream(kryo, s)
   }
+
+  /**
+   * Returns true if auto-reset is on. The only reason this would be false is if the user-supplied
+   * registrator explicitly turns auto-reset off.
+   */
+  def getAutoReset(): Boolean = {
+    val field = classOf[Kryo].getDeclaredField("autoReset")
+    field.setAccessible(true)
+    field.get(kryo).asInstanceOf[Boolean]
+  }
 }
 
 /**
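
The getAutoReset() hook above is what lets the sort shuffle decide whether Kryo output is relocatable; it mirrors the useSerializedPairBuffer check added to ExternalSorter later in this diff. A hedged sketch of how it is consumed follows. KryoSerializerInstance is private[spark], so this only compiles from inside the org.apache.spark package hierarchy, and the object name is made up.

package org.apache.spark.example

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{KryoSerializer, KryoSerializerInstance}

object AutoResetCheckSketch {
  def main(args: Array[String]): Unit = {
    val serInstance = new KryoSerializer(new SparkConf()).newInstance()
    // Serialized map outputs are only safe when Kryo auto-reset is on, i.e. no
    // user-supplied registrator has called kryo.setAutoReset(false).
    val relocatable = serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset()
    println(s"Kryo auto-reset enabled: $relocatable")
  }
}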

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ca6e971..c381672 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -101,7 +101,12 @@ abstract class SerializerInstance {
  */
 @DeveloperApi
 abstract class SerializationStream {
+  /** The most general-purpose method to write an object. */
   def writeObject[T: ClassTag](t: T): SerializationStream
+  /** Writes the object representing the key of a key-value pair. */
+  def writeKey[T: ClassTag](key: T): SerializationStream = writeObject(key)
+  /** Writes the object representing the value of a key-value pair. */
+  def writeValue[T: ClassTag](value: T): SerializationStream = writeObject(value)
   def flush(): Unit
   def close(): Unit
 
@@ -120,7 +125,12 @@ abstract class SerializationStream {
  */
 @DeveloperApi
 abstract class DeserializationStream {
+  /** The most general-purpose method to read an object. */
   def readObject[T: ClassTag](): T
+  /** Reads the object representing the key of a key-value pair. */
+  def readKey[T: ClassTag](): T = readObject[T]()
+  /** Reads the object representing the value of a key-value pair. */
+  def readValue[T: ClassTag](): T = readObject[T]()
   def close(): Unit
 
   /**
@@ -141,4 +151,25 @@ abstract class DeserializationStream {
       DeserializationStream.this.close()
     }
   }
+
+  /**
+   * Read the elements of this stream through an iterator over key-value pairs. This can only be
+   * called once, as reading each element will consume data from the input source.
+   */
+  def asKeyValueIterator: Iterator[(Any, Any)] = new NextIterator[(Any, Any)] {
+    override protected def getNext() = {
+      try {
+        (readKey[Any](), readValue[Any]())
+      } catch {
+        case eof: EOFException => {
+          finished = true
+          null
+        }
+      }
+    }
+
+    override protected def close() {
+      DeserializationStream.this.close()
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 755f17d..cd27c9e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -63,7 +63,7 @@ private[spark] class HashShuffleWriter[K, V](
 
     for (elem <- iter) {
       val bucketId = dep.partitioner.getPartition(elem._1)
-      shuffle.writers(bucketId).write(elem)
+      shuffle.writers(bucketId).write(elem._1, elem._2)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 1483379..499dd97 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -33,7 +33,7 @@ import org.apache.spark.util.Utils
  * This interface does not support concurrent writes. Also, once the writer has
  * been opened, it cannot be reopened again.
  */
-private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
+private[spark] abstract class BlockObjectWriter(val blockId: BlockId) extends OutputStream {
 
   def open(): BlockObjectWriter
 
@@ -54,9 +54,14 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
   def revertPartialWritesAndClose()
 
   /**
-   * Writes an object.
+   * Writes a key-value pair.
    */
-  def write(value: Any)
+  def write(key: Any, value: Any)
+
+  /**
+   * Notify the writer that a record's worth of bytes has been written with write(kvBytes, offs, len).
+   */
+  def recordWritten()
 
   /**
    * Returns the file segment of committed data that this Writer has written.
@@ -203,12 +208,32 @@ private[spark] class DiskBlockObjectWriter(
     }
   }
 
-  override def write(value: Any) {
+  override def write(key: Any, value: Any) {
+    if (!initialized) {
+      open()
+    }
+
+    objOut.writeKey(key)
+    objOut.writeValue(value)
+    numRecordsWritten += 1
+    writeMetrics.incShuffleRecordsWritten(1)
+
+    if (numRecordsWritten % 32 == 0) {
+      updateBytesWritten()
+    }
+  }
+
+  override def write(b: Int): Unit = throw new UnsupportedOperationException()
+
+  override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
     if (!initialized) {
       open()
     }
 
-    objOut.writeObject(value)
+    bs.write(kvBytes, offs, len)
+  }
+
+  override def recordWritten(): Unit = {
     numRecordsWritten += 1
     writeMetrics.incShuffleRecordsWritten(1)
 
@@ -238,7 +263,7 @@ private[spark] class DiskBlockObjectWriter(
   }
 
   // For testing
-  private[spark] def flush() {
+  private[spark] override def flush() {
     objOut.flush()
     bs.flush()
   }
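
To summarize the new writer contract in code: deserialized records go through write(key, value), while records that are already serialized (as in PartitionedSerializedPairBuffer) are copied as raw bytes and followed by recordWritten() so record counts and byte metrics stay correct. The sketch below is illustrative only; the helper names are made up, and it assumes code living in org.apache.spark.storage because BlockObjectWriter is private[spark].

package org.apache.spark.storage

object WriterPathsSketch {
  /** Object path: the writer serializes the pair through its own stream. */
  def writeDeserialized(writer: BlockObjectWriter, key: Any, value: Any): Unit = {
    writer.write(key, value)
  }

  /**
   * Byte path: the record is already serialized, so copy the raw bytes and then
   * tell the writer that one full record has been written.
   */
  def writeSerialized(writer: BlockObjectWriter, kvBytes: Array[Byte]): Unit = {
    writer.write(kvBytes, 0, kvBytes.length)
    writer.recordWritten()
  }
}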

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index f337952..d0faab6 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,14 +17,12 @@
 
 package org.apache.spark.storage
 
-import java.io.{InputStream, IOException}
 import java.util.concurrent.LinkedBlockingQueue
 
 import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Try}
 
 import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.BlockTransferService
 import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.serializer.{SerializerInstance, Serializer}
@@ -301,7 +299,7 @@ final class ShuffleBlockFetcherIterator(
         // the scheduler gets a FetchFailedException.
         Try(buf.createInputStream()).map { is0 =>
           val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asIterator
+          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
           CompletionIterator[Any, Iterator[Any]](iter, {
             // Once the iterator is exhausted, release the buffer and set currentResult to null
             // so we don't release it again in cleanup.

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
new file mode 100644
index 0000000..a60bffe
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/ChainedBuffer.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.OutputStream
+
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * A logical byte buffer that wraps a list of byte arrays. All the byte arrays have equal size. The
+ * advantage of this over a standard ArrayBuffer is that it can grow without claiming large amounts
+ * of memory and needing to copy the full contents. The disadvantage is that the contents don't
+ * occupy a contiguous segment of memory.
+ */
+private[spark] class ChainedBuffer(chunkSize: Int) {
+  private val chunkSizeLog2 = (math.log(chunkSize) / math.log(2)).toInt
+  assert(math.pow(2, chunkSizeLog2).toInt == chunkSize,
+    s"ChainedBuffer chunk size $chunkSize must be a power of two")
+  private val chunks: ArrayBuffer[Array[Byte]] = new ArrayBuffer[Array[Byte]]()
+  private var _size: Int = _
+
+  /**
+   * Feed bytes from this buffer into a BlockObjectWriter.
+   *
+   * @param pos Offset in the buffer to read from.
+   * @param os OutputStream to read into.
+   * @param len Number of bytes to read.
+   */
+  def read(pos: Int, os: OutputStream, len: Int): Unit = {
+    if (pos + len > _size) {
+      throw new IndexOutOfBoundsException(
+        s"Read of $len bytes at position $pos would go past size ${_size} of buffer")
+    }
+    var chunkIndex = pos >> chunkSizeLog2
+    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+    var written = 0
+    while (written < len) {
+      val toRead = math.min(len - written, chunkSize - posInChunk)
+      os.write(chunks(chunkIndex), posInChunk, toRead)
+      written += toRead
+      chunkIndex += 1
+      posInChunk = 0
+    }
+  }
+
+  /**
+   * Read bytes from this buffer into a byte array.
+   *
+   * @param pos Offset in the buffer to read from.
+   * @param bytes Byte array to read into.
+   * @param offs Offset in the byte array to read to.
+   * @param len Number of bytes to read.
+   */
+  def read(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+    if (pos + len > _size) {
+      throw new IndexOutOfBoundsException(
+        s"Read of $len bytes at position $pos would go past size of buffer")
+    }
+    var chunkIndex = pos >> chunkSizeLog2
+    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+    var written = 0
+    while (written < len) {
+      val toRead = math.min(len - written, chunkSize - posInChunk)
+      System.arraycopy(chunks(chunkIndex), posInChunk, bytes, offs + written, toRead)
+      written += toRead
+      chunkIndex += 1
+      posInChunk = 0
+    }
+  }
+
+  /**
+   * Write bytes from a byte array into this buffer.
+   *
+   * @param pos Offset in the buffer to write to.
+   * @param bytes Byte array to write from.
+   * @param offs Offset in the byte array to write from.
+   * @param len Number of bytes to write.
+   */
+  def write(pos: Int, bytes: Array[Byte], offs: Int, len: Int): Unit = {
+    if (pos > _size) {
+      throw new IndexOutOfBoundsException(
+        s"Write at position $pos starts after end of buffer ${_size}")
+    }
+    // Grow if needed
+    val endChunkIndex = (pos + len - 1) >> chunkSizeLog2
+    while (endChunkIndex >= chunks.length) {
+      chunks += new Array[Byte](chunkSize)
+    }
+
+    var chunkIndex = pos >> chunkSizeLog2
+    var posInChunk = pos - (chunkIndex << chunkSizeLog2)
+    var written = 0
+    while (written < len) {
+      val toWrite = math.min(len - written, chunkSize - posInChunk)
+      System.arraycopy(bytes, offs + written, chunks(chunkIndex), posInChunk, toWrite)
+      written += toWrite
+      chunkIndex += 1
+      posInChunk = 0
+    }
+
+    _size = math.max(_size, pos + len)
+  }
+
+  /**
+   * Total size of buffer that can be written to without allocating additional memory.
+   */
+  def capacity: Int = chunks.size * chunkSize
+
+  /**
+   * Size of the logical buffer.
+   */
+  def size: Int = _size
+}
+
+/**
+ * Output stream that writes to a ChainedBuffer.
+ */
+private[spark] class ChainedBufferOutputStream(chainedBuffer: ChainedBuffer) extends OutputStream {
+  private var pos = 0
+
+  override def write(b: Int): Unit = {
+    throw new UnsupportedOperationException()
+  }
+
+  override def write(bytes: Array[Byte], offs: Int, len: Int): Unit = {
+    chainedBuffer.write(pos, bytes, offs, len)
+    pos += len
+  }
+}
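
A brief usage sketch of ChainedBuffer, under stated assumptions: the driver object is hypothetical, and because the class is private[spark] the code is assumed to live in org.apache.spark.util.collection. It uses the 4 MB chunk size that ExternalSorter's spark.shuffle.sort.kvChunkSize default passes to PartitionedSerializedPairBuffer in this patch.

package org.apache.spark.util.collection

object ChainedBufferSketch {
  def main(args: Array[String]): Unit = {
    val buffer = new ChainedBuffer(1 << 22)   // chunk size must be a power of two

    // Appending writes grow the buffer chunk by chunk; existing chunks are
    // never copied, unlike a resizing array.
    val record = Array.fill[Byte](40)(1.toByte)
    buffer.write(0, record, 0, record.length)
    buffer.write(buffer.size, record, 0, record.length)
    assert(buffer.size == 80)

    // Read a slice back out into a plain byte array.
    val out = new Array[Byte](40)
    buffer.read(40, out, 0, out.length)
    println(out.toSeq)
  }
}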

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index f912049..b850973 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -174,7 +174,7 @@ class ExternalAppendOnlyMap[K, V, C](
       val it = currentMap.destructiveSortedIterator(keyComparator)
       while (it.hasNext) {
         val kv = it.next()
-        writer.write(kv)
+        writer.write(kv._1, kv._2)
         objectsWritten += 1
 
         if (objectsWritten == serializerBatchSize) {
@@ -435,7 +435,9 @@ class ExternalAppendOnlyMap[K, V, C](
      */
     private def readNextItem(): (K, C) = {
       try {
-        val item = deserializeStream.readObject().asInstanceOf[(K, C)]
+        val k = deserializeStream.readKey().asInstanceOf[K]
+        val c = deserializeStream.readValue().asInstanceOf[C]
+        val item = (k, c)
         objectsRead += 1
         if (objectsRead == serializerBatchSize) {
           objectsRead = 0

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 4ed8a74..b7306cd 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable
 import com.google.common.io.ByteStreams
 
 import org.apache.spark._
-import org.apache.spark.serializer.{DeserializationStream, Serializer}
+import org.apache.spark.serializer._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.storage.{BlockObjectWriter, BlockId}
 
@@ -66,10 +66,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
  *
  * At a high level, this class works internally as follows:
  *
- * - We repeatedly fill up buffers of in-memory data, using either a SizeTrackingAppendOnlyMap if
- *   we want to combine by key, or an simple SizeTrackingBuffer if we don't. Inside these buffers,
- *   we sort elements of type ((Int, K), C) where the Int is the partition ID. This is done to
- *   avoid calling the partitioner multiple times on the same key (e.g. for RangePartitioner).
+ * - We repeatedly fill up buffers of in-memory data, using either a PartitionedAppendOnlyMap if
+ *   we want to combine by key, or a PartitionedSerializedPairBuffer or PartitionedPairBuffer if we
+ *   don't. Inside these buffers, we sort elements by partition ID and then possibly also by key.
+ *   To avoid calling the partitioner multiple times with each key, we store the partition ID
+ *   alongside each record.
  *
  * - When each buffer reaches our memory limit, we spill it to a file. This file is sorted first
  *   by partition ID and possibly second by key or by hash code of the key, if we want to do
@@ -96,7 +97,7 @@ private[spark] class ExternalSorter[K, V, C](
     partitioner: Option[Partitioner] = None,
     ordering: Option[Ordering[K]] = None,
     serializer: Option[Serializer] = None)
-  extends Logging with Spillable[SizeTrackingPairCollection[(Int, K), C]] {
+  extends Logging with Spillable[WritablePartitionedPairCollection[K, C]] {
 
   private val numPartitions = partitioner.map(_.numPartitions).getOrElse(1)
   private val shouldPartition = numPartitions > 1
@@ -126,11 +127,22 @@ private[spark] class ExternalSorter[K, V, C](
     if (shouldPartition) partitioner.get.getPartition(key) else 0
   }
 
+  private val metaInitialRecords = 256
+  private val kvChunkSize = conf.getInt("spark.shuffle.sort.kvChunkSize", 1 << 22) // 4 MB
+  private val useSerializedPairBuffer =
+    !ordering.isDefined && conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true) &&
+    ser.isInstanceOf[KryoSerializer] &&
+    serInstance.asInstanceOf[KryoSerializerInstance].getAutoReset
+
   // Data structures to store in-memory objects before we spill. Depending on whether we have an
   // Aggregator set, we either put objects into an AppendOnlyMap where we combine them, or we
   // store them in an array buffer.
-  private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
-  private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
+  private var map = new PartitionedAppendOnlyMap[K, C]
+  private var buffer = if (useSerializedPairBuffer) {
+    new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
+  } else {
+    new PartitionedPairBuffer[K, C]
+  }
 
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
@@ -163,33 +175,6 @@ private[spark] class ExternalSorter[K, V, C](
     }
   })
 
-  // A comparator for (Int, K) pairs that orders them by only their partition ID
-  private val partitionComparator: Comparator[(Int, K)] = new Comparator[(Int, K)] {
-    override def compare(a: (Int, K), b: (Int, K)): Int = {
-      a._1 - b._1
-    }
-  }
-
-  // A comparator that orders (Int, K) pairs by partition ID and then possibly by key
-  private val partitionKeyComparator: Comparator[(Int, K)] = {
-    if (ordering.isDefined || aggregator.isDefined) {
-      // Sort by partition ID then key comparator
-      new Comparator[(Int, K)] {
-        override def compare(a: (Int, K), b: (Int, K)): Int = {
-          val partitionDiff = a._1 - b._1
-          if (partitionDiff != 0) {
-            partitionDiff
-          } else {
-            keyComparator.compare(a._2, b._2)
-          }
-        }
-      }
-    } else {
-      // Just sort it by partition ID
-      partitionComparator
-    }
-  }
-
   // Information about a spilled file. Includes sizes in bytes of "batches" written by the
   // serializer as we periodically reset its stream, as well as number of elements in each
   // partition, used to efficiently keep track of partitions when merging.
@@ -221,16 +206,18 @@ private[spark] class ExternalSorter[K, V, C](
     } else if (bypassMergeSort) {
       // SPARK-4479: Also bypass buffering if merge sort is bypassed to avoid defensive copies
       if (records.hasNext) {
-        spillToPartitionFiles(records.map { kv =>
-          ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
-        })
+        spillToPartitionFiles(
+          WritablePartitionedIterator.fromIterator(records.map { kv =>
+            ((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+          })
+        )
       }
     } else {
       // Stick values into our buffer
       while (records.hasNext) {
         addElementsRead()
         val kv = records.next()
-        buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
+        buffer.insert(getPartition(kv._1), kv._1, kv._2.asInstanceOf[C])
         maybeSpillCollection(usingMap = false)
       }
     }
@@ -248,11 +235,15 @@ private[spark] class ExternalSorter[K, V, C](
 
     if (usingMap) {
       if (maybeSpill(map, map.estimateSize())) {
-        map = new SizeTrackingAppendOnlyMap[(Int, K), C]
+        map = new PartitionedAppendOnlyMap[K, C]
       }
     } else {
       if (maybeSpill(buffer, buffer.estimateSize())) {
-        buffer = new SizeTrackingPairBuffer[(Int, K), C]
+        buffer = if (useSerializedPairBuffer) {
+          new PartitionedSerializedPairBuffer[K, C](metaInitialRecords, kvChunkSize, serInstance)
+        } else {
+          new PartitionedPairBuffer[K, C]
+        }
       }
     }
   }
@@ -260,7 +251,7 @@ private[spark] class ExternalSorter[K, V, C](
   /**
    * Spill the current in-memory collection to disk, adding a new file to spills, and clear it.
    */
-  override protected[this] def spill(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+  override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
     if (bypassMergeSort) {
       spillToPartitionFiles(collection)
     } else {
@@ -277,7 +268,7 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param collection whichever collection we're using (map or buffer)
    */
-  private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
+  private def spillToMergeableFile(collection: WritablePartitionedPairCollection[K, C]): Unit = {
     assert(!bypassMergeSort)
 
     // Because these files may be read during shuffle, their compression must be controlled by
@@ -308,14 +299,10 @@ private[spark] class ExternalSorter[K, V, C](
 
     var success = false
     try {
-      val it = collection.destructiveSortedIterator(partitionKeyComparator)
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
       while (it.hasNext) {
-        val elem = it.next()
-        val partitionId = elem._1._1
-        val key = elem._1._2
-        val value = elem._2
-        writer.write(key)
-        writer.write(value)
+        val partitionId = it.nextPartition()
+        it.writeNext(writer)
         elementsPerPartition(partitionId) += 1
         objectsWritten += 1
 
@@ -357,11 +344,11 @@ private[spark] class ExternalSorter[K, V, C](
    *
    * @param collection whichever collection we're using (map or buffer)
    */
-  private def spillToPartitionFiles(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
-    spillToPartitionFiles(collection.iterator)
+  private def spillToPartitionFiles(collection: WritablePartitionedPairCollection[K, C]): Unit = {
+    spillToPartitionFiles(collection.writablePartitionedIterator())
   }
 
-  private def spillToPartitionFiles(iterator: Iterator[((Int, K), C)]): Unit = {
+  private def spillToPartitionFiles(iterator: WritablePartitionedIterator): Unit = {
     assert(bypassMergeSort)
 
     // Create our file writers if we haven't done so yet
@@ -385,11 +372,8 @@ private[spark] class ExternalSorter[K, V, C](
 
     // No need to sort stuff, just write each element out
     while (iterator.hasNext) {
-      val elem = iterator.next()
-      val partitionId = elem._1._1
-      val key = elem._1._2
-      val value = elem._2
-      partitionWriters(partitionId).write((key, value))
+      val partitionId = iterator.nextPartition()
+      iterator.writeNext(partitionWriters(partitionId))
     }
   }
 
@@ -618,8 +602,8 @@ private[spark] class ExternalSorter[K, V, C](
       if (finished || deserializeStream == null) {
         return null
       }
-      val k = deserializeStream.readObject().asInstanceOf[K]
-      val c = deserializeStream.readObject().asInstanceOf[C]
+      val k = deserializeStream.readKey().asInstanceOf[K]
+      val c = deserializeStream.readValue().asInstanceOf[C]
       lastPartitionId = partitionId
       // Start reading the next batch if we're done with this one
       indexInBatch += 1
@@ -695,27 +679,27 @@ private[spark] class ExternalSorter[K, V, C](
    */
    def partitionedIterator: Iterator[(Int, Iterator[Product2[K, C]])] = {
     val usingMap = aggregator.isDefined
-    val collection: SizeTrackingPairCollection[(Int, K), C] = if (usingMap) map else buffer
+    val collection: WritablePartitionedPairCollection[K, C] = if (usingMap) map else buffer
     if (spills.isEmpty && partitionWriters == null) {
       // Special case: if we have only in-memory data, we don't need to merge streams, and perhaps
       // we don't even need to sort by anything other than partition ID
       if (!ordering.isDefined) {
         // The user hasn't requested sorted keys, so only sort by partition ID, not key
-        groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+        groupByPartition(collection.partitionedDestructiveSortedIterator(None))
       } else {
         // We do need to sort by both partition ID and key
-        groupByPartition(collection.destructiveSortedIterator(partitionKeyComparator))
+        groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
       }
     } else if (bypassMergeSort) {
       // Read data from each partition file and merge it together with the data in memory;
       // note that there's no ordering or aggregator in this case -- we just partition objects
-      val collIter = groupByPartition(collection.destructiveSortedIterator(partitionComparator))
+      val collIter = groupByPartition(collection.partitionedDestructiveSortedIterator(None))
       collIter.map { case (partitionId, values) =>
         (partitionId, values ++ readPartitionFile(partitionWriters(partitionId)))
       }
     } else {
       // Merge spilled and in-memory data
-      merge(spills, collection.destructiveSortedIterator(partitionKeyComparator))
+      merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
     }
   }
 
@@ -762,15 +746,29 @@ private[spark] class ExternalSorter[K, V, C](
         context.taskMetrics.shuffleWriteMetrics.foreach(
           _.incShuffleWriteTime(System.nanoTime - writeStartTime))
       }
+    } else if (spills.isEmpty && partitionWriters == null) {
+      // Case where we only have in-memory data
+      val collection = if (aggregator.isDefined) map else buffer
+      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
+      while (it.hasNext) {
+        val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+          context.taskMetrics.shuffleWriteMetrics.get)
+        val partitionId = it.nextPartition()
+        while (it.hasNext && it.nextPartition() == partitionId) {
+          it.writeNext(writer)
+        }
+        writer.commitAndClose()
+        val segment = writer.fileSegment()
+        lengths(partitionId) = segment.length
+      }
     } else {
-      // Either we're not bypassing merge-sort or we have only in-memory data; get an iterator by
-      // partition and just write everything directly.
+      // Not bypassing merge-sort; get an iterator by partition and just write everything directly.
       for ((id, elements) <- this.partitionedIterator) {
         if (elements.hasNext) {
           val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
             context.taskMetrics.shuffleWriteMetrics.get)
           for (elem <- elements) {
-            writer.write(elem)
+            writer.write(elem._1, elem._2)
           }
           writer.commitAndClose()
           val segment = writer.fileSegment()
@@ -799,7 +797,7 @@ private[spark] class ExternalSorter[K, V, C](
     if (writer.isOpen) {
       writer.commitAndClose()
     }
-    blockManager.diskStore.getValues(writer.blockId, ser).get.asInstanceOf[Iterator[Product2[K, C]]]
+    new PairIterator[K, C](blockManager.diskStore.getValues(writer.blockId, ser).get)
   }
 
   def stop(): Unit = {
@@ -829,6 +827,14 @@ private[spark] class ExternalSorter[K, V, C](
     (0 until numPartitions).iterator.map(p => (p, new IteratorForPartition(p, buffered)))
   }
 
+  private def comparator: Option[Comparator[K]] = {
+    if (ordering.isDefined || aggregator.isDefined) {
+      Some(keyComparator)
+    } else {
+      None
+    }
+  }
+
   /**
    * An iterator that reads only the elements for a given partition ID from an underlying buffered
    * stream, assuming this partition is the next one to be read. Used to make it easier to return

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
new file mode 100644
index 0000000..d75959f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PairIterator.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+private[spark] class PairIterator[K, V](iter: Iterator[Any]) extends Iterator[(K, V)] {
+  def hasNext: Boolean = iter.hasNext
+
+  def next(): (K, V) = (iter.next().asInstanceOf[K], iter.next().asInstanceOf[V])
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
new file mode 100644
index 0000000..e2e2f1f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedAppendOnlyMap.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.util.collection.WritablePartitionedPairCollection._
+
+/**
+ * Implementation of WritablePartitionedPairCollection that wraps a map in which the keys are tuples
+ * of (partition ID, K)
+ */
+private[spark] class PartitionedAppendOnlyMap[K, V]
+  extends SizeTrackingAppendOnlyMap[(Int, K), V] with WritablePartitionedPairCollection[K, V] {
+
+  def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+    : Iterator[((Int, K), V)] = {
+    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+    destructiveSortedIterator(comparator)
+  }
+
+  def writablePartitionedIterator(): WritablePartitionedIterator = {
+    WritablePartitionedIterator.fromIterator(super.iterator)
+  }
+
+  def insert(partition: Int, key: K, value: V): Unit = {
+    update((partition, key), value)
+  }
+}
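
A minimal usage sketch of the new map, assuming it is compiled inside the
org.apache.spark.util.collection package (the class is private[spark]); the
object name and sample records below are hypothetical and not part of this
patch. Each insert carries a partition ID, and the destructive sorted iterator
orders records by partition ID only when no key comparator is supplied.

    package org.apache.spark.util.collection

    object PartitionedAppendOnlyMapSketch {
      def main(args: Array[String]): Unit = {
        val map = new PartitionedAppendOnlyMap[String, Int]
        // Each record is stored under a composite (partition ID, key) key.
        map.insert(1, "b", 2)
        map.insert(0, "a", 1)
        map.insert(1, "a", 3)
        // No key comparator: sort by partition ID only.
        map.partitionedDestructiveSortedIterator(None).foreach {
          case ((partition, key), value) =>
            println(s"partition=$partition key=$key value=$value")
        }
      }
    }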

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
new file mode 100644
index 0000000..e8332e1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.util.collection.WritablePartitionedPairCollection._
+
+/**
+ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that keeps track
+ * of its estimated size in bytes.
+ */
+private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64)
+  extends WritablePartitionedPairCollection[K, V] with SizeTracker
+{
+  require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
+  require(initialCapacity >= 1, "Invalid initial capacity")
+
+  // Basic growable array data structure. We use a single array of AnyRef to hold both the keys
+  // and the values, so that we can sort them efficiently with KVArraySortDataFormat.
+  private var capacity = initialCapacity
+  private var curSize = 0
+  private var data = new Array[AnyRef](2 * initialCapacity)
+
+  /** Add an element into the buffer */
+  def insert(partition: Int, key: K, value: V): Unit = {
+    if (curSize == capacity) {
+      growArray()
+    }
+    data(2 * curSize) = (partition, key.asInstanceOf[AnyRef])
+    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
+    curSize += 1
+    afterUpdate()
+  }
+
+  /** Double the size of the array because we've reached capacity */
+  private def growArray(): Unit = {
+    if (capacity == (1 << 29)) {
+      // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+      throw new Exception("Can't grow buffer beyond 2^29 elements")
+    }
+    val newCapacity = capacity * 2
+    val newArray = new Array[AnyRef](2 * newCapacity)
+    System.arraycopy(data, 0, newArray, 0, 2 * capacity)
+    data = newArray
+    capacity = newCapacity
+    resetSamples()
+  }
+
+  /** Iterate through the data in a given order. For this class this is not really destructive. */
+  override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+    : Iterator[((Int, K), V)] = {
+    val comparator = keyComparator.map(partitionKeyComparator).getOrElse(partitionComparator)
+    new Sorter(new KVArraySortDataFormat[(Int, K), AnyRef]).sort(data, 0, curSize, comparator)
+    iterator
+  }
+
+  override def writablePartitionedIterator(): WritablePartitionedIterator = {
+    WritablePartitionedIterator.fromIterator(iterator)
+  }
+
+  private def iterator(): Iterator[((Int, K), V)] = new Iterator[((Int, K), V)] {
+    var pos = 0
+
+    override def hasNext: Boolean = pos < curSize
+
+    override def next(): ((Int, K), V) = {
+      if (!hasNext) {
+        throw new NoSuchElementException
+      }
+      val pair = (data(2 * pos).asInstanceOf[(Int, K)], data(2 * pos + 1).asInstanceOf[V])
+      pos += 1
+      pair
+    }
+  }
+}
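
For the deserialized buffer path, a similar hypothetical sketch (same package
assumption as above): the buffer stores the (partition ID, key) tuple and the
value in adjacent slots of a single AnyRef array and sorts them in place before
iterating, so a key comparator can be layered on top of the partition order.

    package org.apache.spark.util.collection

    import java.util.Comparator

    object PartitionedPairBufferSketch {
      def main(args: Array[String]): Unit = {
        val buffer = new PartitionedPairBuffer[String, Long](initialCapacity = 4)
        buffer.insert(1, "b", 10L)
        buffer.insert(0, "c", 20L)
        buffer.insert(0, "a", 30L)
        // Order by partition ID first, then by key within each partition.
        val byKey: Comparator[String] = new Comparator[String] {
          override def compare(a: String, b: String): Int = a.compareTo(b)
        }
        buffer.partitionedDestructiveSortedIterator(Some(byKey)).foreach {
          case ((partition, key), value) => println(s"$partition $key $value")
        }
      }
    }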

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
new file mode 100644
index 0000000..b5ca0c6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedSerializedPairBuffer.scala
@@ -0,0 +1,254 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.io.InputStream
+import java.nio.IntBuffer
+import java.util.Comparator
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.serializer.{JavaSerializerInstance, SerializerInstance}
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.util.collection.PartitionedSerializedPairBuffer._
+
+/**
+ * Append-only buffer of key-value pairs, each with a corresponding partition ID, that serializes
+ * its records upon insert and stores them as raw bytes.
+ *
+ * We use two data-structures to store the contents. The serialized records are stored in a
+ * ChainedBuffer that can expand gracefully as records are added. This buffer is accompanied by a
+ * metadata buffer that stores pointers into the data buffer as well as the partition ID of each
+ * record. Each entry in the metadata buffer takes up a fixed amount of space.
+ *
+ * Sorting the collection means swapping entries in the metadata buffer - the record buffer need not
+ * be modified at all. Storing the partition IDs in the metadata buffer means that comparisons can
+ * happen without following any pointers, which should minimize cache misses.
+ *
+ * Currently, only sorting by partition is supported.
+ *
+ * @param metaInitialRecords The initial number of entries in the metadata buffer.
+ * @param kvBlockSize The size of each byte buffer in the ChainedBuffer used to store the records.
+ * @param serializerInstance the serializer used for serializing inserted records.
+ */
+private[spark] class PartitionedSerializedPairBuffer[K, V](
+    metaInitialRecords: Int,
+    kvBlockSize: Int,
+    serializerInstance: SerializerInstance)
+  extends WritablePartitionedPairCollection[K, V] with SizeTracker {
+
+  if (serializerInstance.isInstanceOf[JavaSerializerInstance]) {
+    throw new IllegalArgumentException("PartitionedSerializedPairBuffer does not support" +
+      " Java-serialized objects.")
+  }
+
+  private var metaBuffer = IntBuffer.allocate(metaInitialRecords * RECORD_SIZE)
+
+  private val kvBuffer: ChainedBuffer = new ChainedBuffer(kvBlockSize)
+  private val kvOutputStream = new ChainedBufferOutputStream(kvBuffer)
+  private val kvSerializationStream = serializerInstance.serializeStream(kvOutputStream)
+
+  def insert(partition: Int, key: K, value: V): Unit = {
+    if (metaBuffer.position == metaBuffer.capacity) {
+      growMetaBuffer()
+    }
+
+    val keyStart = kvBuffer.size
+    if (keyStart < 0) {
+      throw new Exception(s"Can't grow buffer beyond ${1 << 31} bytes")
+    }
+    kvSerializationStream.writeObject[Any](key)
+    kvSerializationStream.flush()
+    val valueStart = kvBuffer.size
+    kvSerializationStream.writeObject[Any](value)
+    kvSerializationStream.flush()
+    val valueEnd = kvBuffer.size
+
+    metaBuffer.put(keyStart)
+    metaBuffer.put(valueStart)
+    metaBuffer.put(valueEnd)
+    metaBuffer.put(partition)
+  }
+
+  /** Double the size of the metadata buffer because we've reached capacity */
+  private def growMetaBuffer(): Unit = {
+    if (metaBuffer.capacity.toLong * 2 > Int.MaxValue) {
+      // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
+      throw new Exception(s"Can't grow buffer beyond ${Int.MaxValue} bytes")
+    }
+    val newMetaBuffer = IntBuffer.allocate(metaBuffer.capacity * 2)
+    newMetaBuffer.put(metaBuffer.array)
+    metaBuffer = newMetaBuffer
+  }
+
+  /** Iterate through the data in a given order. For this class this is not really destructive. */
+  override def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+    : Iterator[((Int, K), V)] = {
+    sort(keyComparator)
+    val is = orderedInputStream
+    val deserStream = serializerInstance.deserializeStream(is)
+    new Iterator[((Int, K), V)] {
+      var metaBufferPos = 0
+      def hasNext: Boolean = metaBufferPos < metaBuffer.position
+      def next(): ((Int, K), V) = {
+        val key = deserStream.readKey[Any]().asInstanceOf[K]
+        val value = deserStream.readValue[Any]().asInstanceOf[V]
+        val partition = metaBuffer.get(metaBufferPos + PARTITION)
+        metaBufferPos += RECORD_SIZE
+        ((partition, key), value)
+      }
+    }
+  }
+
+  override def estimateSize: Long = metaBuffer.capacity * 4 + kvBuffer.capacity
+
+  override def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
+    : WritablePartitionedIterator = {
+    sort(keyComparator)
+    writablePartitionedIterator
+  }
+
+  override def writablePartitionedIterator(): WritablePartitionedIterator = {
+    new WritablePartitionedIterator {
+      // current position in the meta buffer in ints
+      var pos = 0
+
+      def writeNext(writer: BlockObjectWriter): Unit = {
+        val keyStart = metaBuffer.get(pos + KEY_START)
+        val valueEnd = metaBuffer.get(pos + VAL_END)
+        pos += RECORD_SIZE
+        kvBuffer.read(keyStart, writer, valueEnd - keyStart)
+        writer.recordWritten()
+      }
+      def nextPartition(): Int = metaBuffer.get(pos + PARTITION)
+      def hasNext(): Boolean = pos < metaBuffer.position
+    }
+  }
+
+  // Visible for testing
+  def orderedInputStream: OrderedInputStream = {
+    new OrderedInputStream(metaBuffer, kvBuffer)
+  }
+
+  private def sort(keyComparator: Option[Comparator[K]]): Unit = {
+    val comparator = if (keyComparator.isEmpty) {
+      new Comparator[Int]() {
+        def compare(partition1: Int, partition2: Int): Int = {
+          partition1 - partition2
+        }
+      }
+    } else {
+      throw new UnsupportedOperationException()
+    }
+
+    val sorter = new Sorter(new SerializedSortDataFormat)
+    sorter.sort(metaBuffer, 0, metaBuffer.position / RECORD_SIZE, comparator)
+  }
+}
+
+private[spark] class OrderedInputStream(metaBuffer: IntBuffer, kvBuffer: ChainedBuffer)
+    extends InputStream {
+
+  private var metaBufferPos = 0
+  private var kvBufferPos =
+    if (metaBuffer.position > 0) metaBuffer.get(metaBufferPos + KEY_START) else 0
+
+  override def read(bytes: Array[Byte]): Int = read(bytes, 0, bytes.length)
+
+  override def read(bytes: Array[Byte], offs: Int, len: Int): Int = {
+    if (metaBufferPos >= metaBuffer.position) {
+      return -1
+    }
+    val bytesRemainingInRecord = metaBuffer.get(metaBufferPos + VAL_END) - kvBufferPos
+    val toRead = math.min(bytesRemainingInRecord, len)
+    kvBuffer.read(kvBufferPos, bytes, offs, toRead)
+    if (toRead == bytesRemainingInRecord) {
+      metaBufferPos += RECORD_SIZE
+      if (metaBufferPos < metaBuffer.position) {
+        kvBufferPos = metaBuffer.get(metaBufferPos + KEY_START)
+      }
+    } else {
+      kvBufferPos += toRead
+    }
+    toRead
+  }
+
+  override def read(): Int = {
+    throw new UnsupportedOperationException()
+  }
+}
+
+private[spark] class SerializedSortDataFormat extends SortDataFormat[Int, IntBuffer] {
+
+  private val META_BUFFER_TMP = new Array[Int](RECORD_SIZE)
+
+  /** Return the sort key for the element at the given index. */
+  override protected def getKey(metaBuffer: IntBuffer, pos: Int): Int = {
+    metaBuffer.get(pos * RECORD_SIZE + PARTITION)
+  }
+
+  /** Swap two elements. */
+  override def swap(metaBuffer: IntBuffer, pos0: Int, pos1: Int): Unit = {
+    val iOff = pos0 * RECORD_SIZE
+    val jOff = pos1 * RECORD_SIZE
+    System.arraycopy(metaBuffer.array, iOff, META_BUFFER_TMP, 0, RECORD_SIZE)
+    System.arraycopy(metaBuffer.array, jOff, metaBuffer.array, iOff, RECORD_SIZE)
+    System.arraycopy(META_BUFFER_TMP, 0, metaBuffer.array, jOff, RECORD_SIZE)
+  }
+
+  /** Copy a single element from src(srcPos) to dst(dstPos). */
+  override def copyElement(
+      src: IntBuffer,
+      srcPos: Int,
+      dst: IntBuffer,
+      dstPos: Int): Unit = {
+    val srcOff = srcPos * RECORD_SIZE
+    val dstOff = dstPos * RECORD_SIZE
+    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE)
+  }
+
+  /**
+   * Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
+   * Overlapping ranges are allowed.
+   */
+  override def copyRange(
+      src: IntBuffer,
+      srcPos: Int,
+      dst: IntBuffer,
+      dstPos: Int,
+      length: Int): Unit = {
+    val srcOff = srcPos * RECORD_SIZE
+    val dstOff = dstPos * RECORD_SIZE
+    System.arraycopy(src.array, srcOff, dst.array, dstOff, RECORD_SIZE * length)
+  }
+
+  /**
+   * Allocates a Buffer that can hold up to 'length' elements.
+   * All elements of the buffer should be considered invalid until data is explicitly copied in.
+   */
+  override def allocate(length: Int): IntBuffer = {
+    IntBuffer.allocate(length * RECORD_SIZE)
+  }
+}
+
+private[spark] object PartitionedSerializedPairBuffer {
+  val KEY_START = 0
+  val VAL_START = 1
+  val VAL_END = 2
+  val PARTITION = 3
+  val RECORD_SIZE = Seq(KEY_START, VAL_START, VAL_END, PARTITION).size // num ints of metadata
+}
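
A hypothetical construction sketch for the serialized buffer (same package
assumption; the object name and record values are illustrative only). It uses
KryoSerializer, since the constructor above rejects JavaSerializerInstance: the
buffer treats serialized records as relocatable byte ranges, a property that
Java serialization's stream headers and back-references would break. Every
insert appends RECORD_SIZE (= 4) ints of metadata (key start, value start,
value end, and partition ID), and sorting rearranges only those ints.

    package org.apache.spark.util.collection

    import org.apache.spark.SparkConf
    import org.apache.spark.serializer.KryoSerializer

    object PartitionedSerializedPairBufferSketch {
      def main(args: Array[String]): Unit = {
        val serInstance = new KryoSerializer(new SparkConf(false)).newInstance()
        val buffer = new PartitionedSerializedPairBuffer[String, Int](
          metaInitialRecords = 16, kvBlockSize = 4096, serializerInstance = serInstance)
        buffer.insert(2, "x", 1)
        buffer.insert(0, "y", 2)
        // Only sorting by partition is supported, so pass None for the key comparator.
        buffer.partitionedDestructiveSortedIterator(None).foreach {
          case ((partition, key), value) => println(s"$partition $key $value")
        }
      }
    }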

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
index eb4de41..722f78b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala
@@ -21,7 +21,7 @@ package org.apache.spark.util.collection
  * An append-only map that keeps track of its estimated size in bytes.
  */
 private[spark] class SizeTrackingAppendOnlyMap[K, V]
-  extends AppendOnlyMap[K, V] with SizeTracker with SizeTrackingPairCollection[K, V]
+  extends AppendOnlyMap[K, V] with SizeTracker
 {
   override def update(key: K, value: V): Unit = {
     super.update(key, value)

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
deleted file mode 100644
index 9e9c16c..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairBuffer.scala
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.util.Comparator
-
-/**
- * Append-only buffer of key-value pairs that keeps track of its estimated size in bytes.
- */
-private[spark] class SizeTrackingPairBuffer[K, V](initialCapacity: Int = 64)
-  extends SizeTracker with SizeTrackingPairCollection[K, V]
-{
-  require(initialCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
-  require(initialCapacity >= 1, "Invalid initial capacity")
-
-  // Basic growable array data structure. We use a single array of AnyRef to hold both the keys
-  // and the values, so that we can sort them efficiently with KVArraySortDataFormat.
-  private var capacity = initialCapacity
-  private var curSize = 0
-  private var data = new Array[AnyRef](2 * initialCapacity)
-
-  /** Add an element into the buffer */
-  def insert(key: K, value: V): Unit = {
-    if (curSize == capacity) {
-      growArray()
-    }
-    data(2 * curSize) = key.asInstanceOf[AnyRef]
-    data(2 * curSize + 1) = value.asInstanceOf[AnyRef]
-    curSize += 1
-    afterUpdate()
-  }
-
-  /** Total number of elements in buffer */
-  override def size: Int = curSize
-
-  /** Iterate over the elements of the buffer */
-  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
-    var pos = 0
-
-    override def hasNext: Boolean = pos < curSize
-
-    override def next(): (K, V) = {
-      if (!hasNext) {
-        throw new NoSuchElementException
-      }
-      val pair = (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
-      pos += 1
-      pair
-    }
-  }
-
-  /** Double the size of the array because we've reached capacity */
-  private def growArray(): Unit = {
-    if (capacity == (1 << 29)) {
-      // Doubling the capacity would create an array bigger than Int.MaxValue, so don't
-      throw new Exception("Can't grow buffer beyond 2^29 elements")
-    }
-    val newCapacity = capacity * 2
-    val newArray = new Array[AnyRef](2 * newCapacity)
-    System.arraycopy(data, 0, newArray, 0, 2 * capacity)
-    data = newArray
-    capacity = newCapacity
-    resetSamples()
-  }
-
-  /** Iterate through the data in a given order. For this class this is not really destructive. */
-  override def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)] = {
-    new Sorter(new KVArraySortDataFormat[K, AnyRef]).sort(data, 0, curSize, keyComparator)
-    iterator
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
deleted file mode 100644
index faa4e2b..0000000
--- a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingPairCollection.scala
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection
-
-import java.util.Comparator
-
-/**
- * A common interface for our size-tracking collections of key-value pairs, which are used in
- * external operations. These all support estimating the size and obtaining a memory-efficient
- * sorted iterator.
- */
-// TODO: should extend Iterable[Product2[K, V]] instead of (K, V)
-private[spark] trait SizeTrackingPairCollection[K, V] extends Iterable[(K, V)] {
-  /** Estimate the collection's current memory usage in bytes. */
-  def estimateSize(): Long
-
-  /** Iterate through the data in a given key order. This may destroy the underlying collection. */
-  def destructiveSortedIterator(keyComparator: Comparator[K]): Iterator[(K, V)]
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
new file mode 100644
index 0000000..f26d161
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/WritablePartitionedPairCollection.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+import org.apache.spark.storage.BlockObjectWriter
+
+/**
+ * A common interface for size-tracking collections of key-value pairs that
+ * - Have an associated partition for each key-value pair.
+ *  - Support a memory-efficient sorted iterator.
+ * - Support a WritablePartitionedIterator for writing the contents directly as bytes.
+ */
+private[spark] trait WritablePartitionedPairCollection[K, V] {
+  /**
+   * Insert a key-value pair with a partition into the collection
+   */
+  def insert(partition: Int, key: K, value: V): Unit
+
+  /**
+   * Iterate through the data in order of partition ID and then the given comparator. This may
+   * destroy the underlying collection.
+   */
+  def partitionedDestructiveSortedIterator(keyComparator: Option[Comparator[K]])
+    : Iterator[((Int, K), V)]
+
+  /**
+   * Iterate through the data and write out the elements instead of returning them. Records are
+   * returned in order of their partition ID and then the given comparator.
+   * This may destroy the underlying collection.
+   */
+  def destructiveSortedWritablePartitionedIterator(keyComparator: Option[Comparator[K]])
+    : WritablePartitionedIterator = {
+    WritablePartitionedIterator.fromIterator(partitionedDestructiveSortedIterator(keyComparator))
+  }
+
+  /**
+   * Iterate through the data and write out the elements instead of returning them.
+   */
+  def writablePartitionedIterator(): WritablePartitionedIterator
+}
+
+private[spark] object WritablePartitionedPairCollection {
+  /**
+   * A comparator for (Int, K) pairs that orders them by only their partition ID.
+   */
+  def partitionComparator[K]: Comparator[(Int, K)] = new Comparator[(Int, K)] {
+    override def compare(a: (Int, K), b: (Int, K)): Int = {
+      a._1 - b._1
+    }
+  }
+
+  /**
+   * A comparator for (Int, K) pairs that orders them both by their partition ID and a key ordering.
+   */
+  def partitionKeyComparator[K](keyComparator: Comparator[K]): Comparator[(Int, K)] = {
+    new Comparator[(Int, K)] {
+      override def compare(a: (Int, K), b: (Int, K)): Int = {
+        val partitionDiff = a._1 - b._1
+        if (partitionDiff != 0) {
+          partitionDiff
+        } else {
+          keyComparator.compare(a._2, b._2)
+        }
+      }
+    }
+  }
+}
+
+/**
+ * Iterator that writes elements to a BlockObjectWriter instead of returning them. Each element
+ * has an associated partition.
+ */
+private[spark] trait WritablePartitionedIterator {
+  def writeNext(writer: BlockObjectWriter): Unit
+
+  def hasNext(): Boolean
+
+  def nextPartition(): Int
+}
+
+private[spark] object WritablePartitionedIterator {
+  def fromIterator(it: Iterator[((Int, _), _)]): WritablePartitionedIterator = {
+    new WritablePartitionedIterator {
+      var cur = if (it.hasNext) it.next() else null
+
+      def writeNext(writer: BlockObjectWriter): Unit = {
+        writer.write(cur._1._2, cur._2)
+        cur = if (it.hasNext) it.next() else null
+      }
+
+      def hasNext(): Boolean = cur != null
+
+      def nextPartition(): Int = cur._1._1
+    }
+  }
+}
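
The two comparators defined above are what the map and buffer implementations
feed to their sorters. A small hypothetical sketch of their ordering behaviour
(the object name and sample data are illustrative only):

    package org.apache.spark.util.collection

    import java.util.Comparator

    import org.apache.spark.util.collection.WritablePartitionedPairCollection._

    object PartitionComparatorSketch {
      def main(args: Array[String]): Unit = {
        val byKey: Comparator[String] = new Comparator[String] {
          override def compare(a: String, b: String): Int = a.compareTo(b)
        }
        val records = Seq((1, "b"), (0, "z"), (1, "a"), (0, "a"))
        // partitionComparator looks only at the partition ID.
        println(records.sorted(Ordering.comparatorToOrdering(partitionComparator[String])))
        // partitionKeyComparator breaks ties within a partition using the key comparator.
        println(records.sorted(Ordering.comparatorToOrdering(partitionKeyComparator(byKey))))
      }
    }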

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 1b13559..778a7ee 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -280,6 +280,15 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
     val thrown = intercept[SparkException](ser.serialize(largeObject))
     assert(thrown.getMessage.contains(kryoBufferMaxProperty))
   }
+
+  test("getAutoReset") {
+    val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance]
+    assert(ser.getAutoReset)
+    val conf = new SparkConf().set("spark.kryo.registrator",
+      classOf[RegistratorWithoutAutoReset].getName)
+    val ser2 = new KryoSerializer(conf).newInstance().asInstanceOf[KryoSerializerInstance]
+    assert(!ser2.getAutoReset)
+  }
 }
 
 
@@ -313,4 +322,10 @@ object KryoTest {
       k.register(classOf[java.util.HashMap[_, _]])
     }
   }
+
+  class RegistratorWithoutAutoReset extends KryoRegistrator {
+    override def registerClasses(k: Kryo) {
+      k.setAutoReset(false)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
index 963264c..86fcf44 100644
--- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
@@ -24,7 +24,7 @@ import scala.reflect.ClassTag
 
 
 /**
- * A serializer implementation that always return a single element in a deserialization stream.
+ * A serializer implementation that always returns two elements in a deserialization stream.
  */
 class TestSerializer extends Serializer {
   override def newInstance(): TestSerializerInstance = new TestSerializerInstance
@@ -51,7 +51,7 @@ class TestDeserializationStream extends DeserializationStream {
 
   override def readObject[T: ClassTag](): T = {
     count += 1
-    if (count == 2) {
+    if (count == 3) {
       throw new EOFException
     }
     new Object().asInstanceOf[T]

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index 7d76435..84384bb 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -59,8 +59,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
     val shuffle1 = shuffleBlockManager.forMapTask(1, 1, 1, new JavaSerializer(conf),
       new ShuffleWriteMetrics)
     for (writer <- shuffle1.writers) {
-      writer.write("test1")
-      writer.write("test2")
+      writer.write("test1", "value")
+      writer.write("test2", "value")
     }
     for (writer <- shuffle1.writers) {
       writer.commitAndClose()
@@ -73,8 +73,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
       new ShuffleWriteMetrics)
 
     for (writer <- shuffle2.writers) {
-      writer.write("test3")
-      writer.write("test4")
+      writer.write("test3", "value")
+      writer.write("test4", "vlue")
     }
     for (writer <- shuffle2.writers) {
       writer.commitAndClose()
@@ -91,8 +91,8 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
     val shuffle3 = shuffleBlockManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
       new ShuffleWriteMetrics)
     for (writer <- shuffle3.writers) {
-      writer.write("test3")
-      writer.write("test4")
+      writer.write("test3", "value")
+      writer.write("test4", "value")
     }
     for (writer <- shuffle3.writers) {
       writer.commitAndClose()

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
index 003a728..43ef469 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -32,7 +32,7 @@ class BlockObjectWriterSuite extends FunSuite {
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
 
-    writer.write(Long.box(20))
+    writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
     assert(writeMetrics.shuffleRecordsWritten === 1)
     // Metrics don't update on every write
@@ -40,7 +40,7 @@ class BlockObjectWriterSuite extends FunSuite {
     // After 32 writes, metrics should update
     for (i <- 0 until 32) {
       writer.flush()
-      writer.write(Long.box(i))
+      writer.write(Long.box(i), Long.box(i))
     }
     assert(writeMetrics.shuffleBytesWritten > 0)
     assert(writeMetrics.shuffleRecordsWritten === 33)
@@ -54,7 +54,7 @@ class BlockObjectWriterSuite extends FunSuite {
     val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
       new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
 
-    writer.write(Long.box(20))
+    writer.write(Long.box(20), Long.box(30))
     // Record metrics update on every write
     assert(writeMetrics.shuffleRecordsWritten === 1)
     // Metrics don't update on every write
@@ -62,7 +62,7 @@ class BlockObjectWriterSuite extends FunSuite {
     // After 32 writes, metrics should update
     for (i <- 0 until 32) {
       writer.flush()
-      writer.write(Long.box(i))
+      writer.write(Long.box(i), Long.box(i))
     }
     assert(writeMetrics.shuffleBytesWritten > 0)
     assert(writeMetrics.shuffleRecordsWritten === 33)

http://git-wip-us.apache.org/repos/asf/spark/blob/0a2b15ce/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
new file mode 100644
index 0000000..c0c38cd
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/collection/ChainedBufferSuite.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.nio.ByteBuffer
+
+import org.scalatest.FunSuite
+import org.scalatest.Matchers._
+
+class ChainedBufferSuite extends FunSuite {
+  test("write and read at start") {
+    // write from start of source array
+    val buffer = new ChainedBuffer(8)
+    buffer.capacity should be (0)
+    verifyWriteAndRead(buffer, 0, 0, 0, 4)
+    buffer.capacity should be (8)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 0, 5, 0, 4)
+    buffer.capacity should be (8)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 0, 0, 5, 4)
+    buffer.capacity should be (8)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 0, 0, 0, 8)
+    buffer.capacity should be (8)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 0, 0, 0, 12)
+    buffer.capacity should be (16)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 0, 0, 0, 28)
+    buffer.capacity should be (32)
+  }
+
+  test("write and read at middle") {
+    val buffer = new ChainedBuffer(8)
+
+    // fill to a middle point
+    verifyWriteAndRead(buffer, 0, 0, 0, 3)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 3, 0, 0, 4)
+    buffer.capacity should be (8)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 3, 5, 0, 4)
+    buffer.capacity should be (8)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 3, 0, 5, 4)
+    buffer.capacity should be (8)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 3, 0, 0, 5)
+    buffer.capacity should be (8)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 3, 0, 0, 12)
+    buffer.capacity should be (16)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 3, 0, 0, 28)
+    buffer.capacity should be (32)
+  }
+
+  test("write and read at later buffer") {
+    val buffer = new ChainedBuffer(8)
+
+    // fill to a middle point
+    verifyWriteAndRead(buffer, 0, 0, 0, 11)
+
+    // write from start of source array
+    verifyWriteAndRead(buffer, 11, 0, 0, 4)
+    buffer.capacity should be (16)
+
+    // write from middle of source array
+    verifyWriteAndRead(buffer, 11, 5, 0, 4)
+    buffer.capacity should be (16)
+
+    // read to middle of target array
+    verifyWriteAndRead(buffer, 11, 0, 5, 4)
+    buffer.capacity should be (16)
+
+    // write up to border
+    verifyWriteAndRead(buffer, 11, 0, 0, 5)
+    buffer.capacity should be (16)
+
+    // expand into second buffer
+    verifyWriteAndRead(buffer, 11, 0, 0, 12)
+    buffer.capacity should be (24)
+
+    // expand into multiple buffers
+    verifyWriteAndRead(buffer, 11, 0, 0, 28)
+    buffer.capacity should be (40)
+  }
+
+
+  // Used to make sure we're writing different bytes each time
+  var rangeStart = 0
+
+  /**
+   * @param buffer The buffer to write to and read from.
+   * @param offsetInBuffer The offset to write to in the buffer.
+   * @param offsetInSource The offset in the array that the bytes are written from.
+   * @param offsetInTarget The offset in the array to read the bytes into.
+   * @param length The number of bytes to read and write
+   */
+  def verifyWriteAndRead(
+      buffer: ChainedBuffer,
+      offsetInBuffer: Int,
+      offsetInSource: Int,
+      offsetInTarget: Int,
+      length: Int): Unit = {
+    val source = new Array[Byte](offsetInSource + length)
+    (rangeStart until rangeStart + length).map(_.toByte).copyToArray(source, offsetInSource)
+    buffer.write(offsetInBuffer, source, offsetInSource, length)
+    val target = new Array[Byte](offsetInTarget + length)
+    buffer.read(offsetInBuffer, target, offsetInTarget, length)
+    ByteBuffer.wrap(source, offsetInSource, length) should be (
+      ByteBuffer.wrap(target, offsetInTarget, length))
+
+    rangeStart += 100
+  }
+}

