Posted to commits@spark.apache.org by ma...@apache.org on 2014/06/24 05:25:52 UTC

git commit: [SPARK-2124] Move aggregation into shuffle implementations

Repository: spark
Updated Branches:
  refs/heads/master 51c816837 -> 56eb8af18


[SPARK-2124] Move aggregation into shuffle implementations

This PR is a sub-task of SPARK-2044; it moves the execution of aggregation into the shuffle implementations.

I leave `CoGroupedRDD` and `SubtractedRDD` unchanged because they have their own implementations of aggregation. I'm not sure whether it is suitable to change these two RDDs.

I also do not move the sort-related code of `OrderedRDDFunctions` into the shuffle; that will be addressed in another sub-task.
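
At a glance, the reworked path boils down to the following (a simplified sketch of the `PairRDDFunctions.combineByKey` change in the diff below; `self`, `partitioner`, `serializer`, `aggregator` and `mapSideCombine` are the surrounding values of that method):

    // Aggregation is no longer done with extra mapPartitions passes around the shuffle;
    // the ShuffledRDD carries the Aggregator and a mapSideCombine flag down to the
    // shuffle reader and writer, which perform the combine themselves.
    new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
      .setSerializer(serializer)
      .setAggregator(aggregator)
      .setMapSideCombine(mapSideCombine)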

Author: jerryshao <sa...@intel.com>

Closes #1064 from jerryshao/SPARK-2124 and squashes the following commits:

4a05a40 [jerryshao] Modify according to comments
1f7dcc8 [jerryshao] Style changes
50a2fd6 [jerryshao] Fix test suite issue after moving aggregator to Shuffle reader and writer
1a96190 [jerryshao] Code modification related to the ShuffledRDD
308f635 [jerryshao] initial work on moving the combiner to ShuffleManager's reader and writer


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56eb8af1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56eb8af1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56eb8af1

Branch: refs/heads/master
Commit: 56eb8af187b19f09810baafb314b21e07cf0a79c
Parents: 51c8168
Author: jerryshao <sa...@intel.com>
Authored: Mon Jun 23 20:25:46 2014 -0700
Committer: Matei Zaharia <ma...@databricks.com>
Committed: Mon Jun 23 20:25:46 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/Dependency.scala     |  3 +-
 .../apache/spark/rdd/OrderedRDDFunctions.scala  |  2 +-
 .../org/apache/spark/rdd/PairRDDFunctions.scala | 20 +++--------
 .../main/scala/org/apache/spark/rdd/RDD.scala   |  2 +-
 .../org/apache/spark/rdd/ShuffledRDD.scala      | 37 ++++++++++++++++----
 .../apache/spark/scheduler/ShuffleMapTask.scala |  6 ++--
 .../apache/spark/shuffle/ShuffleWriter.scala    |  4 +--
 .../spark/shuffle/hash/HashShuffleReader.scala  | 20 +++++++++--
 .../spark/shuffle/hash/HashShuffleWriter.scala  | 23 +++++++++---
 .../org/apache/spark/CheckpointSuite.scala      |  2 +-
 .../scala/org/apache/spark/ShuffleSuite.scala   | 22 +++++++-----
 .../scala/org/apache/spark/rdd/RDDSuite.scala   | 14 ++++----
 .../spark/scheduler/SparkListenerSuite.scala    |  2 +-
 .../spark/graphx/impl/MessageToPartition.scala  |  7 ++--
 .../graphx/impl/RoutingTablePartition.scala     |  4 +--
 .../apache/spark/sql/execution/Exchange.scala   |  6 ++--
 .../spark/sql/execution/basicOperators.scala    |  2 +-
 17 files changed, 112 insertions(+), 64 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/Dependency.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index c8c194a..09a6057 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -61,7 +61,8 @@ class ShuffleDependency[K, V, C](
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
-    val aggregator: Option[Aggregator[K, V, C]] = None)
+    val aggregator: Option[Aggregator[K, V, C]] = None,
+    val mapSideCombine: Boolean = false)
   extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
   val shuffleId: Int = rdd.context.newShuffleId()
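
For illustration, a dependency that asks the shuffle layer to combine on the map side can now be built roughly like this (a sketch only; `pairs` is assumed to be an existing RDD[(String, Int)] and `sumAgg` an Aggregator[String, Int, Int], with serializer and keyOrdering left at their defaults):

    import org.apache.spark.{HashPartitioner, ShuffleDependency}

    val dep = new ShuffleDependency[String, Int, Int](
      pairs,                       // the map-side RDD of key/value records (assumed)
      new HashPartitioner(2),
      aggregator = Some(sumAgg),   // combine functions (assumed Aggregator)
      mapSideCombine = true)       // new flag: ask the shuffle writer to pre-combine per key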

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 6a3f698..f1f4b43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -57,7 +57,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    */
   def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
     val part = new RangePartitioner(numPartitions, self, ascending)
-    val shuffled = new ShuffledRDD[K, V, P](self, part)
+    val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
     shuffled.mapPartitions(iter => {
       val buf = iter.toArray
       if (ascending) {

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 443d1c5..2b2b9ae 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -90,21 +90,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
       self.mapPartitionsWithContext((context, iter) => {
         new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
       }, preservesPartitioning = true)
-    } else if (mapSideCombine) {
-      val combined = self.mapPartitionsWithContext((context, iter) => {
-        aggregator.combineValuesByKey(iter, context)
-      }, preservesPartitioning = true)
-      val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
-        .setSerializer(serializer)
-      partitioned.mapPartitionsWithContext((context, iter) => {
-        new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter, context))
-      }, preservesPartitioning = true)
     } else {
-      // Don't apply map-side combiner.
-      val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializer)
-      values.mapPartitionsWithContext((context, iter) => {
-        new InterruptibleIterator(context, aggregator.combineValuesByKey(iter, context))
-      }, preservesPartitioning = true)
+      new ShuffledRDD[K, V, C, (K, C)](self, partitioner)
+        .setSerializer(serializer)
+        .setAggregator(aggregator)
+        .setMapSideCombine(mapSideCombine)
     }
   }
 
@@ -401,7 +391,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
     if (self.partitioner == Some(partitioner)) {
       self
     } else {
-      new ShuffledRDD[K, V, (K, V)](self, partitioner)
+      new ShuffledRDD[K, V, V, (K, V)](self, partitioner)
     }
   }
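
From the caller's point of view nothing changes: the `mapSideCombine` flag that `combineByKey` already accepts now simply flows into the `ShuffledRDD` (and from there into the `ShuffleDependency`) instead of driving extra map passes. A hedged usage sketch, assuming an existing SparkContext `sc`:

    import org.apache.spark.SparkContext._   // pair-RDD implicits (Spark 1.x style)
    import org.apache.spark.HashPartitioner

    val counts = sc.parallelize(Seq("a", "b", "a"))
      .map(w => (w, 1))
      .combineByKey(
        (v: Int) => v,                  // createCombiner
        (c: Int, v: Int) => c + v,      // mergeValue
        (c1: Int, c2: Int) => c1 + c2,  // mergeCombiners
        new HashPartitioner(2),
        mapSideCombine = true)
    // counts.collect() => Array(("a", 2), ("b", 1)), in some order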
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/rdd/RDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index cebfd10..4e841bc 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -340,7 +340,7 @@ abstract class RDD[T: ClassTag](
 
       // include a shuffle step so that our upstream tasks are still distributed
       new CoalescedRDD(
-        new ShuffledRDD[Int, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
+        new ShuffledRDD[Int, T, T, (Int, T)](mapPartitionsWithIndex(distributePartition),
         new HashPartitioner(numPartitions)),
         numPartitions).values
     } else {

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index bb108ef..bf02f68 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, SparkEnv, TaskContext}
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.serializer.Serializer
 
@@ -35,23 +35,48 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
  * @param part the partitioner used to partition the RDD
  * @tparam K the key class.
  * @tparam V the value class.
+ * @tparam C the combiner class.
  */
 @DeveloperApi
-class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
-    @transient var prev: RDD[P],
+class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
+    @transient var prev: RDD[_ <: Product2[K, V]],
     part: Partitioner)
   extends RDD[P](prev.context, Nil) {
 
   private var serializer: Option[Serializer] = None
 
+  private var keyOrdering: Option[Ordering[K]] = None
+
+  private var aggregator: Option[Aggregator[K, V, C]] = None
+
+  private var mapSideCombine: Boolean = false
+
   /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
-  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
+  def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
     this.serializer = Option(serializer)
     this
   }
 
+  /** Set key ordering for RDD's shuffle. */
+  def setKeyOrdering(keyOrdering: Ordering[K]): ShuffledRDD[K, V, C, P] = {
+    this.keyOrdering = Option(keyOrdering)
+    this
+  }
+
+  /** Set aggregator for RDD's shuffle. */
+  def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
+    this.aggregator = Option(aggregator)
+    this
+  }
+
+  /** Set mapSideCombine flag for RDD's shuffle. */
+  def setMapSideCombine(mapSideCombine: Boolean): ShuffledRDD[K, V, C, P] = {
+    this.mapSideCombine = mapSideCombine
+    this
+  }
+
   override def getDependencies: Seq[Dependency[_]] = {
-    List(new ShuffleDependency(prev, part, serializer))
+    List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
   }
 
   override val partitioner = Some(part)
@@ -61,7 +86,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[P] = {
-    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
+    val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
     SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
       .read()
       .asInstanceOf[Iterator[P]]
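
To make the new type parameters and setters concrete, here is a self-contained sketch that drives the DeveloperApi directly for a word count, with K = String, V = Int, C = Int (`sc` is an assumed SparkContext):

    import org.apache.spark.{Aggregator, HashPartitioner}
    import org.apache.spark.rdd.ShuffledRDD

    val pairs = sc.parallelize(Seq(("a", 1), ("b", 1), ("a", 1)))
    val sumAgg = new Aggregator[String, Int, Int](
      v => v,                // createCombiner
      (c, v) => c + v,       // mergeValue
      (c1, c2) => c1 + c2)   // mergeCombiners

    val counts = new ShuffledRDD[String, Int, Int, (String, Int)](pairs, new HashPartitioner(2))
      .setAggregator(sumAgg)
      .setMapSideCombine(true)
    // counts.collect() => Array(("a", 2), ("b", 1)), in some order; the combining now
    // happens inside the shuffle writer/reader rather than in wrapper RDDs.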

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 859cdc5..fdaf1de 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -144,10 +144,8 @@ private[spark] class ShuffleMapTask(
     try {
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
-      for (elem <- rdd.iterator(split, context)) {
-        writer.write(elem.asInstanceOf[Product2[Any, Any]])
-      }
-      writer.stop(success = true).get
+      writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      return writer.stop(success = true).get
     } catch {
       case e: Exception =>
         if (writer != null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
index ead3ebd..b934480 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -23,8 +23,8 @@ import org.apache.spark.scheduler.MapStatus
  * Obtained inside a map task to write out records to the shuffle system.
  */
 private[spark] trait ShuffleWriter[K, V] {
-  /** Write a record to this task's output */
-  def write(record: Product2[K, V]): Unit
+  /** Write a bunch of records to this task's output */
+  def write(records: Iterator[_ <: Product2[K, V]]): Unit
 
   /** Close this writer, passing along whether the map completed */
   def stop(success: Boolean): Option[MapStatus]
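
For implementers, the contract now hands the writer the whole iterator of map output instead of one record at a time, so a writer can traverse (and, as HashShuffleWriter does below, pre-combine) the stream itself. A minimal illustrative sketch; the package and class are hypothetical, and it only compiles under `org.apache.spark` because the trait is `private[spark]`:

    package org.apache.spark.shuffle.hypothetical   // hypothetical location, for illustration

    import org.apache.spark.scheduler.MapStatus
    import org.apache.spark.shuffle.ShuffleWriter

    /** Toy writer that only counts the records it is given; stop() reports no MapStatus. */
    private[spark] class CountingShuffleWriter[K, V] extends ShuffleWriter[K, V] {
      private var written = 0L

      override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
        records.foreach(_ => written += 1)
      }

      override def stop(success: Boolean): Option[MapStatus] = None
    }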

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index f6a7903..d45258c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -17,9 +17,9 @@
 
 package org.apache.spark.shuffle.hash
 
+import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
-import org.apache.spark.TaskContext
 
 class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
@@ -31,10 +31,24 @@ class HashShuffleReader[K, C](
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")
 
+  private val dep = handle.dependency
+
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
-      Serializer.getSerializer(handle.dependency.serializer))
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
+      Serializer.getSerializer(dep.serializer))
+
+    if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
+      } else {
+        new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
+      }
+    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
+      throw new IllegalStateException("Aggregator is empty for map-side combine")
+    } else {
+      iter
+    }
   }
 
   /** Close this reader */
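
The two branches above mirror what HashShuffleWriter (next file) does on the map side: with mapSideCombine set, the writer has already turned raw (K, V) records into (K, C) combiners via combineValuesByKey, so the reader only merges combiners; otherwise the reader builds the combiners itself from raw values. A plain-Scala sketch of the two Aggregator paths (hypothetical helper functions, not Spark API, using a word-count Aggregator[String, Int, Int]):

    import org.apache.spark.Aggregator

    val agg = new Aggregator[String, Int, Int](v => v, _ + _, _ + _)

    // Roughly what combineValuesByKey does: build combiners from raw map output.
    def combineValues(iter: Iterator[(String, Int)]): Map[String, Int] =
      iter.foldLeft(Map.empty[String, Int]) { case (m, (k, v)) =>
        m.updated(k, m.get(k).map(c => agg.mergeValue(c, v)).getOrElse(agg.createCombiner(v)))
      }

    // Roughly what combineCombinersByKey does: merge already-combined per-map results.
    def combineCombiners(iter: Iterator[(String, Int)]): Map[String, Int] =
      iter.foldLeft(Map.empty[String, Int]) { case (m, (k, c)) =>
        m.updated(k, m.get(k).map(prev => agg.mergeCombiners(prev, c)).getOrElse(c))
      }

    combineValues(Iterator(("a", 1), ("a", 1), ("b", 1)))      // Map(a -> 2, b -> 1)
    combineCombiners(Iterator(("a", 2), ("a", 3), ("b", 1)))   // Map(a -> 5, b -> 1)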

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 4c67490..9b78228 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -40,11 +40,24 @@ class HashShuffleWriter[K, V](
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
   private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
 
-  /** Write a record to this task's output */
-  override def write(record: Product2[K, V]): Unit = {
-    val pair = record.asInstanceOf[Product2[Any, Any]]
-    val bucketId = dep.partitioner.getPartition(pair._1)
-    shuffle.writers(bucketId).write(pair)
+  /** Write a bunch of records to this task's output */
+  override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
+    val iter = if (dep.aggregator.isDefined) {
+      if (dep.mapSideCombine) {
+        dep.aggregator.get.combineValuesByKey(records, context)
+      } else {
+        records
+      }
+    } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
+      throw new IllegalStateException("Aggregator is empty for map-side combine")
+    } else {
+      records
+    }
+
+    for (elem <- iter) {
+      val bucketId = dep.partitioner.getPartition(elem._1)
+      shuffle.writers(bucketId).write(elem)
+    }
   }
 
   /** Close this writer, passing along whether the map completed */

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f64f3c9..fc00458 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
   test("ShuffledRDD") {
     testRDD(rdd => {
       // Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
-      new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
+      new ShuffledRDD[Int, Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
     })
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 47112ce..b40fee7 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -56,8 +56,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     }
     // If the Kryo serializer is not used correctly, the shuffle would fail because the
     // default Java serializer cannot handle the non serializable class.
-    val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
-      b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
+    val c = new ShuffledRDD[Int,
+      NonJavaSerializableClass,
+      NonJavaSerializableClass,
+      (Int, NonJavaSerializableClass)](b, new HashPartitioner(NUM_BLOCKS))
+    c.setSerializer(new KryoSerializer(conf))
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
 
     assert(c.count === 10)
@@ -78,8 +81,11 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     }
     // If the Kryo serializer is not used correctly, the shuffle would fail because the
     // default Java serializer cannot handle the non serializable class.
-    val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
-      b, new HashPartitioner(3)).setSerializer(new KryoSerializer(conf))
+    val c = new ShuffledRDD[Int,
+      NonJavaSerializableClass,
+      NonJavaSerializableClass,
+      (Int, NonJavaSerializableClass)](b, new HashPartitioner(3))
+    c.setSerializer(new KryoSerializer(conf))
     assert(c.count === 10)
   }
 
@@ -94,7 +100,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
 
     // NOTE: The default Java serializer doesn't create zero-sized blocks.
     //       So, use Kryo
-    val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
+    val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
       .setSerializer(new KryoSerializer(conf))
 
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
@@ -120,7 +126,7 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     val b = a.map(x => (x, x*2))
 
     // NOTE: The default Java serializer should create zero-sized blocks
-    val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
+    val c = new ShuffledRDD[Int, Int, Int, (Int, Int)](b, new HashPartitioner(10))
 
     val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
     assert(c.count === 4)
@@ -141,8 +147,8 @@ class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
     def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
     val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
     val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
-    val results = new ShuffledRDD[Int, Int, MutablePair[Int, Int]](pairs, new HashPartitioner(2))
-      .collect()
+    val results = new ShuffledRDD[Int, Int, Int, MutablePair[Int, Int]](pairs,
+      new HashPartitioner(2)).collect()
 
     data.foreach { pair => results should contain (pair) }
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 0e5625b..0f9cbe2 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -276,7 +276,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     // we can optionally shuffle to keep the upstream parallel
     val coalesced5 = data.coalesce(1, shuffle = true)
     val isEquals = coalesced5.dependencies.head.rdd.dependencies.head.rdd.
-      asInstanceOf[ShuffledRDD[_, _, _]] != null
+      asInstanceOf[ShuffledRDD[_, _, _, _]] != null
     assert(isEquals)
 
     // when shuffling, we can increase the number of partitions
@@ -509,7 +509,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   test("takeSample") {
     val n = 1000000
     val data = sc.parallelize(1 to n, 2)
-    
+
     for (num <- List(5, 20, 100)) {
       val sample = data.takeSample(withReplacement=false, num=num)
       assert(sample.size === num)        // Got exactly num elements
@@ -704,11 +704,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ancestors3.count(_.isInstanceOf[MappedRDD[_, _]]) === 2)
 
     // Any ancestors before the shuffle are not considered
-    assert(ancestors4.size === 1)
-    assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
-    assert(ancestors5.size === 4)
-    assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _]]) === 1)
-    assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 1)
+    assert(ancestors4.size === 0)
+    assert(ancestors4.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 0)
+    assert(ancestors5.size === 3)
+    assert(ancestors5.count(_.isInstanceOf[ShuffledRDD[_, _, _, _]]) === 1)
+    assert(ancestors5.count(_.isInstanceOf[MapPartitionsRDD[_, _]]) === 0)
     assert(ancestors5.count(_.isInstanceOf[MappedValuesRDD[_, _, _]]) === 2)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index abd7b22..6df0a08 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -181,7 +181,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
     assert(sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
     listener.stageInfos.size should be {2} // Shuffle map stage + result stage
     val stageInfo3 = listener.stageInfos.keys.find(_.stageId == 2).get
-    stageInfo3.rddInfos.size should be {2} // ShuffledRDD, MapPartitionsRDD
+    stageInfo3.rddInfos.size should be {1} // ShuffledRDD
     stageInfo3.rddInfos.forall(_.numPartitions == 4) should be {true}
     stageInfo3.rddInfos.exists(_.name == "Trois") should be {true}
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
index 1c6d7e5..d85afa4 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/MessageToPartition.scala
@@ -62,7 +62,8 @@ class MessageToPartition[@specialized(Int, Long, Double, Char, Boolean/* , AnyRe
 private[graphx]
 class VertexBroadcastMsgRDDFunctions[T: ClassTag](self: RDD[VertexBroadcastMsg[T]]) {
   def partitionBy(partitioner: Partitioner): RDD[VertexBroadcastMsg[T]] = {
-    val rdd = new ShuffledRDD[PartitionID, (VertexId, T), VertexBroadcastMsg[T]](self, partitioner)
+    val rdd = new ShuffledRDD[PartitionID, (VertexId, T), (VertexId, T), VertexBroadcastMsg[T]](
+      self, partitioner)
 
     // Set a custom serializer if the data is of int or double type.
     if (classTag[T] == ClassTag.Int) {
@@ -84,7 +85,7 @@ class MsgRDDFunctions[T: ClassTag](self: RDD[MessageToPartition[T]]) {
    * Return a copy of the RDD partitioned using the specified partitioner.
    */
   def partitionBy(partitioner: Partitioner): RDD[MessageToPartition[T]] = {
-    new ShuffledRDD[PartitionID, T, MessageToPartition[T]](self, partitioner)
+    new ShuffledRDD[PartitionID, T, T, MessageToPartition[T]](self, partitioner)
   }
 
 }
@@ -103,7 +104,7 @@ object MsgRDDFunctions {
 private[graphx]
 class VertexRDDFunctions[VD: ClassTag](self: RDD[(VertexId, VD)]) {
   def copartitionWithVertices(partitioner: Partitioner): RDD[(VertexId, VD)] = {
-    val rdd = new ShuffledRDD[VertexId, VD, (VertexId, VD)](self, partitioner)
+    val rdd = new ShuffledRDD[VertexId, VD, VD, (VertexId, VD)](self, partitioner)
 
     // Set a custom serializer if the data is of int or double type.
     if (classTag[VD] == ClassTag.Int) {

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
----------------------------------------------------------------------
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
index d02e923..3827ac8 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala
@@ -46,8 +46,8 @@ private[graphx]
 class RoutingTableMessageRDDFunctions(self: RDD[RoutingTableMessage]) {
   /** Copartition an `RDD[RoutingTableMessage]` with the vertex RDD with the given `partitioner`. */
   def copartitionWithVertices(partitioner: Partitioner): RDD[RoutingTableMessage] = {
-    new ShuffledRDD[VertexId, (PartitionID, Byte), RoutingTableMessage](self, partitioner)
-      .setSerializer(new RoutingTableMessageSerializer)
+    new ShuffledRDD[VertexId, (PartitionID, Byte), (PartitionID, Byte), RoutingTableMessage](
+      self, partitioner).setSerializer(new RoutingTableMessageSerializer)
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index f46fa05..00010ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -47,7 +47,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(r => mutablePair.update(hashExpressions(r), r))
         }
         val part = new HashPartitioner(numPartitions)
-        val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part)
+        val shuffled = new ShuffledRDD[Row, Row, Row, MutablePair[Row, Row]](rdd, part)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 
@@ -60,7 +60,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(row => mutablePair.update(row, null))
         }
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
-        val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part)
+        val shuffled = new ShuffledRDD[Row, Null, Null, MutablePair[Row, Null]](rdd, part)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
 
         shuffled.map(_._1)
@@ -71,7 +71,7 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           iter.map(r => mutablePair.update(null, r))
         }
         val partitioner = new HashPartitioner(1)
-        val shuffled = new ShuffledRDD[Null, Row, MutablePair[Null, Row]](rdd, partitioner)
+        val shuffled = new ShuffledRDD[Null, Row, Row, MutablePair[Null, Row]](rdd, partitioner)
         shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
         shuffled.map(_._2)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/56eb8af1/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 18f4a58..b40d4e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -105,7 +105,7 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext
       iter.take(limit).map(row => mutablePair.update(false, row))
     }
     val part = new HashPartitioner(1)
-    val shuffled = new ShuffledRDD[Boolean, Row, MutablePair[Boolean, Row]](rdd, part)
+    val shuffled = new ShuffledRDD[Boolean, Row, Row, MutablePair[Boolean, Row]](rdd, part)
     shuffled.setSerializer(new SparkSqlSerializer(new SparkConf(false)))
     shuffled.mapPartitions(_.take(limit).map(_._2))
   }